Braniac
Massively distributed training at scale.
The chief idea is to add fast-interconnect islands to DiLoCo, in order to train foundation models over scattered shards of compute, each on the scale of a commercial mobile phone.
My target roughly corresponds to orchestrating a 1B KellerJordan/modded-nanogpt
run over a minimal cluster of 6 GB RAM devices. These numbers are chosen
for the simple reason that a 1B model fits, in inference mode, on a
6 GB RAM device, with some RAM to spare for the operating system.
(Exact calculations here).
DiLoCo, which stands for Distributed Low Communication training, is a method for training an LLM with very low communication requirements, while remaining nearly optimal in FLOPs used to reach a given perplexity. It has several implementations across the internet, and is one of a few notable methods that have recently arisen to tackle the problem of training models over the internet. See: [Psyche Network], [DeMo], [DisTro], [Prime-Intellect].
For scalable training, a lot needs to be optimized for. The right mix should be somewhere between model parallelism + DisTro + Modula (Bernstein et al.). Model parallelism still needs to be figured out; see DiPaCo.
On Optimization
Talk by Jeremy Bernstein on optimization in the modular norm. Talks about splitting optimizers by assigning different metrics/distance functions to different layers (or “blocks”). Layers are split into atoms, bonds, and compounds; see more in Bernstein’s Cohere talk.
Modula, a library for metrized deep learning (from OpenAI, Thinking Machines).
On Distributed Training
- Bowen Peng on distributed training; primarily a critique of DiLoCo: the communication reduction holds only in the amortized case, i.e., once every H steps there needs to be a huge transfer between GPUs, and if something goes wrong during that window, training stops.
- This is fair; it disqualifies DiLoCo from our consideration as well.
- Instead, he proposes DisTro, which is designed for data parallelism over heterogeneous compute – which is what we want, minus the model parallelism.
- Hence, inspect work by Thinking Machines, and by …
- DiPaCo is something entirely different. Should keep an eye on it.
Personal ToDos:
- Seems like a nice forcing function to learn Rust
MatFormer / MatLM / MatViT — Deep Technical Interview Q&A Bank
This file converts the earlier question bank into a Q/A format.
Caveat (honest + practical): Some “code-reading” questions depend on exact repository implementation details (e.g., precise config schema, internal module names, caching hooks). Where those file-level details are not reliably visible from the GitHub web UI in this environment, answers give the most likely/standard implementation consistent with the paper and the repo READMEs, plus a “verify in code” note describing what to check.
0) Quick calibration prompts (5–10 min)
Q: In one minute: what problem does
MatFormer solve, and why is “train 4 sizes” not good enough?
A: MatFormer trains one set of
weights that supports many compute/quality
points by nesting submodels (primarily via FFN width).
Training four separate sizes is (i) 4× training
cost, (ii) yields inconsistent
behaviours (distributional drift, different logits), and
(iii) can’t cheaply produce exponentially many
intermediate models or allow per-layer mix-and-match. MatFormer
aims for elastic inference (dynamic sizing), with
high behavioural consistency because the models
share weights.
Q: What does “matryoshka” mean in weight
space vs “matryoshka representation learning” in
activation/output space?
A: In MatFormer, “matryoshka” means
nested parameter subspaces: smaller models are
prefix slices of larger ones (e.g., first (h/2)
FFN neurons). In “matryoshka representation learning” (MRL), the
nesting is usually enforced on representations
(e.g., first (d’) dims of an embedding are useful), not
necessarily via shared weights. Weight nesting implies
shared training dynamics and stronger
consistency; representation nesting can be added as an auxiliary
constraint but doesn’t automatically give weight-sharing.
Q: If you could only make one
transformer component elastic (embeddings / attention / FFN), why
does the paper focus on FFN? Quantify the share
of params/latency.
A: In standard Transformers (GPT-like),
FFNs dominate parameters for typical
configurations because FFN params scale as (\approx 2 d h) per layer
(often (h \approx 4d)), while attention projections scale as (\approx
4 d^2). With (h=4d), FFN params are (\approx 8d^2) vs attention (\approx
4d^2) (roughly 2× more in FFN than attention,
per layer). Latency often also has a large FFN contribution
(especially for moderate context lengths), so shrinking FFN offers
a big compute/parameter lever without changing
sequence-length–dependent attention complexity.
Q: Define “behavioral
consistency” between two generative models in a way you
can actually measure.
A: A measurable definition: for matched prompts
and decoding settings, compare (i) distributional
consistency such as average per-token KL
(\mathrm{KL}(p_{small}\,\|\,p_{large})) or cross-entropy of large under small, and
(ii) trajectory consistency such as exact
token match rate under greedy decoding or matched
sampling seeds. Also track accept rate in
speculative decoding as an operational proxy.
1) Core architecture: the nested FFN block
A. Formal definition & invariances
Q: Write the MatFormer FFN for granularity
(g): which rows/cols of (W_1, W_2) are used, and what is the
“first neurons are most significant” claim?
A: Let hidden size for granularity (g) be (h_g)
with (h_g \le h). For input (x \in \mathbb{R}^d), a vanilla FFN is \[\text{FFN}(x)=W_2\,\sigma(W_1 x +
b_1)+b_2\]with (W_1 \in \mathbb{R}^{h \times d}), (W_2 \in \mathbb{R}^{d \times h}).
MatFormer uses the prefix of the hidden
dimension: - (W_{1,g} = W_1[:h_g, :]), (b_{1,g}=b_1[:h_g]) -
(W_{2,g} = W_2[:, :h_g]), (b_2) unchanged
So\[
\text{FFN}_g(x)=W_2[:, :h_g]\,\sigma(W_1[:h_g,:]x+b_1[:h_g]) + b_2
\]“First neurons are most significant” is not true
a priori (hidden units are permutation-symmetric), but
MatFormer’s joint training breaks symmetry by
forcing early units to be used by all submodels,
so they become the most “universally useful” ones.
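A minimal PyTorch sketch of this prefix-sliced FFN (illustrative names; a single GELU FFN rather than the gated variant some LLaMA-style models use):

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class NestedFFN(nn.Module):
    """FFN whose hidden width can be truncated to a prefix at call time."""

    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        # Full-width parameters; every submodel reads prefix slices of these.
        self.W1 = nn.Parameter(torch.randn(d_hidden, d_model) * d_model ** -0.5)
        self.b1 = nn.Parameter(torch.zeros(d_hidden))
        self.W2 = nn.Parameter(torch.randn(d_model, d_hidden) * d_hidden ** -0.5)
        self.b2 = nn.Parameter(torch.zeros(d_model))

    def forward(self, x: torch.Tensor, hg: Optional[int] = None) -> torch.Tensor:
        hg = self.W1.shape[0] if hg is None else hg
        # Slices are autograd views, so gradients flow back into the shared tensors.
        z = F.linear(x, self.W1[:hg, :], self.b1[:hg])          # (..., hg)
        return F.linear(F.gelu(z), self.W2[:, :hg], self.b2)    # (..., d_model)


ffn = NestedFFN(d_model=512, d_hidden=2048)
x = torch.randn(2, 16, 512)
full, half = ffn(x), ffn(x, hg=1024)   # same weights, two granularities
```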
Q: Explain the permutation
symmetry of FFN hidden units in a vanilla transformer.
How does MatFormer break it, and why doesn’t training
just reintroduce arbitrary permutations?
A: In a vanilla FFN, hidden units can be
permuted: if (P) is a permutation matrix, replacing (W_1 \to P W_1) and
(W_2 \to W_2 P^{-1}) leaves the function unchanged (up to corresponding
bias permutation). MatFormer breaks this because submodels must
use the prefix of neurons; a permutation that
moves a “good” neuron out of the prefix would harm smaller
submodels, so the joint loss discourages it. Training won’t
reintroduce arbitrary permutations because the objective
is not permutation-invariant anymore: the prefix
constraint defines an ordering that the optimiser must
respect.
Q: Suppose you initialise MatFormer from a
pretrained dense model: what neuron ordering do you choose and how
(e.g., by norm / Fisher / activation stats / learned reordering)?
What failure modes appear if ordering is bad?
A: Options: - Magnitude / norm
ordering: rank neurons by (\|W_2[:,j]\|) or combined norm
(\|W_1[j,:]\|\,\|W_2[:,j]\|). - Activation stats: run
data through model, rank by mean activation, variance, or
contribution to output norm. - Fisher / saliency:
approximate importance via gradient-based criteria. -
Learned reordering: train a small
permutation/assignment or do fine-tuning with a soft mask then
harden to prefix.
Bad ordering can cause: small submodels underperform (because
important neurons were placed late), unstable joint training
(small loss high → large gradients), and slow convergence until
fine-tuning reassigns importance.
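A sketch of the magnitude/norm ordering option above, assuming plain (W_1, W_2) FFN tensors; permuting hidden units this way leaves the full-width function unchanged, so it is safe to apply once before finetuning a dense checkpoint into MatFormer form:

```python
import torch


@torch.no_grad()
def reorder_ffn_by_norm(W1: torch.Tensor, b1: torch.Tensor, W2: torch.Tensor):
    """Permute FFN hidden units so the 'most important' ones come first.

    Importance heuristic: product of input-row and output-column norms.
    W1: (h, d), b1: (h,), W2: (d, h). Permuting rows of W1/b1 and columns
    of W2 by the same index preserves the full-width function.
    """
    score = W1.norm(dim=1) * W2.norm(dim=0)          # (h,) per-neuron importance
    order = torch.argsort(score, descending=True)    # best neurons first
    return W1[order], b1[order], W2[:, order], order
```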
Q: Does nesting induce an implicit
regularisation? If so, what kind (capacity
control, shared features, implicit distillation), and how would
you test it?
A: Yes: it encourages shared,
general-purpose features in early neurons, acting like
(i) capacity sharing (multiple tasks = multiple
widths) and (ii) an implicit distillation where
smaller models’ gradients steer shared parameters toward broader
utility. Test via: - compare overfitting/validation gap vs
baseline - representational similarity of early neurons across
widths - add/remove small-model losses and see generalisation
changes - evaluate robustness/OOD sensitivity vs baseline.
B. Granularities & parameter accounting
Q: Paper + repo mention exponentially
spaced widths (e.g., (h, h/2, h/4, h/8)). Why
exponential, not linear?
A: Exponential spacing gives (i) roughly
constant relative compute ratio between
neighbouring sizes, (ii) wide coverage of the Pareto frontier with
few points, and (iii) aligns with practical deployment tiers (2×,
4×, 8×). Linear spacing would waste points in regimes where
marginal gains are small and increases training overhead.
Q: Given model dim (d) and FFN hidden (h),
derive FFN parameter count and FLOPs; then compute savings for
each granularity.
A: Parameters per FFN (ignoring biases): (\# = dh + hd = 2dh). If (h_g = h/k),
then (\#_g = 2d(h/k) = (1/k)(2dh)). FLOPs
per token similarly scale with (2dh) for the two matmuls (activation
cost smaller), so savings are roughly proportional to (h_g/h). For
widths (h, h/2, h/4, h/8), savings relative to full are (1, 1/2,
1/4, 1/8) in FFN compute/params.
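A quick sanity-check of these counts (the d and h values below are illustrative, not from the paper):

```python
def ffn_costs(d: int, h: int, factors=(1, 2, 4, 8)) -> None:
    """FFN params (2*d*h_g) and per-token matmul FLOPs (~2 FLOPs per MAC,
    two matmuls) for each nested granularity h_g = h / k."""
    for k in factors:
        hg = h // k
        params = 2 * d * hg           # W1 (hg x d) + W2 (d x hg), biases ignored
        flops = 2 * 2 * d * hg        # two matmuls, ~2 FLOPs per multiply-add
        print(f"h/{k}: params={params:,}  flops/token={flops:,}  fraction={hg / h:.3f}")


ffn_costs(d=2048, h=8192)
```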
Q: If you only shrink FFN, when does attention
dominate latency? What does that imply for diminishing
returns of MatFormer at small widths?
A: Attention cost scales roughly as (O(n d^2))
for projections plus (O(n^2 d)) for attention scores (depending on
kernel/fusion), while FFN scales (O(n d h)). As (h) shrinks, FFN
cost drops linearly but attention stays. For sufficiently small
(h_g), attention becomes dominant, so further shrinking FFN yields
diminishing latency gains. Practically: MatFormer’s best latency
leverage is in regimes where FFN is a large slice of runtime
(moderate context length, large (h)).
Q: The paper says nesting could also apply to
attention heads and/or embeddings. Sketch a
consistent nesting scheme for attention that preserves tensor
shapes and doesn’t require reparameterising projections.
A: One scheme: keep model dimension (d) fixed but
nest heads by prefixing head groups. Implement
(W_Q,W_K,W_V \in \mathbb{R}^{d \times d}) but define head count (H_g \le H) and only use the
first (H_g) heads’ slices in the reshape/split into heads (i.e.,
use the first (H_g \cdot d_h) channels, where (d_h=d/H)). You must keep
(d_h) constant and choose (H) so it divides (d). Alternative: nest
per-head MLP projections via block-diagonal
structure, but that’s more invasive.
C. Numerical / implementation details
Q: How do you implement “first (h_g) neurons”
in PyTorch so gradients flow correctly to shared weights? (Views
vs copies; slicing semantics; contiguous memory;
checkpointing.)
A: Use tensor views (slicing) on
parameters in forward: - W1_g = W1[:hg, :],
b1_g = b1[:hg], W2_g = W2[:, :hg]. This
preserves gradient flow into the underlying parameter storage.
Avoid .clone() or .detach(). Ensure
slices are contiguous if needed for kernel performance (use
.contiguous() cautiously—it creates a copy, but it is a differentiable
op, so gradients still route back to the underlying parameter; in general, prefer
views, and if you must make tensors contiguous, verify gradients reach the shared storage). For
checkpointing, wrap the FFN forward in
torch.utils.checkpoint.checkpoint at the module level
so autograd recomputes correctly.
Q: What is the most efficient way to compute
outputs for all granularities in a forward pass:
naive multiple FFN passes / compute maximal hidden activations
once + reuse slices / compute incremental contributions
(prefix-sums style). Which is correct, and which is fastest on
GPUs?
A: Correct approaches: - Max hidden once
+ slice: compute (z=\sigma(W_1 x + b_1)) for the full (h), then for each
(g) compute output (W_2[:, :h_g]\, z[:h_g]). This shares the
expensive first matmul. - Incremental
contributions: partition hidden into blocks (e.g., ranges
corresponding to (h/2, h/4)) and accumulate outputs via block
matmuls; useful if you want all outputs and can reuse partial
sums.
Naive multiple FFN passes repeats (W_1 x) matmuls and is typically
slowest. Fastest depends on kernel fusion and memory bandwidth;
“max once + multiple smaller GEMMs” is often good, while a
well-fused incremental kernel could be best but requires custom
kernels.
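A sketch of the “max hidden once + slice” strategy; because the activation is elementwise (GELU assumed here), a prefix slice of the full hidden equals the small submodel’s own hidden, so each output is exactly that submodel’s output:

```python
import torch
import torch.nn.functional as F


def all_granularity_outputs(x, W1, b1, W2, b2, widths):
    """Share the expensive W1 matmul: compute the full hidden once,
    then emit one output per granularity from prefix slices of it."""
    z = F.gelu(F.linear(x, W1, b1))                      # (..., h), computed once
    return {hg: F.linear(z[..., :hg], W2[:, :hg], b2)    # (..., d_model) per width
            for hg in widths}
```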
Q: What subtle bug appears if you accidentally
let the smaller model use weights that were meant to be exclusive
to a larger granularity (off-by-one / wrong slice axis)? What unit
test catches it?
A: Bug: small model’s output depends on “late”
neurons, violating nesting; you’ll see unexpectedly good/bad small
performance and broken consistency metrics. Unit tests: -
Dependency test: zero out exclusive weights (late
columns/rows) and verify small model output unchanged. -
Gradient test: backprop small loss and confirm
gradients are zero on exclusive weight regions. -
Shape-axis test: assert slicing is on the hidden
dimension, not input/output.
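The dependency and gradient tests, as a sketch assuming the hypothetical NestedFFN module sketched earlier (constructor NestedFFN(d_model, d_hidden), forward(x, hg=...)):

```python
import copy

import torch


def test_prefix_nesting(ffn, d_model=64, hg=32):
    """Unit tests for the nesting invariant.
    1) Dependency: zeroing 'exclusive' tail weights must not change the hg output.
    2) Gradients: a loss on the hg submodel must leave tail gradients at zero."""
    x = torch.randn(4, d_model)
    out = ffn(x, hg=hg)

    zeroed = copy.deepcopy(ffn)
    with torch.no_grad():
        zeroed.W1[hg:].zero_()
        zeroed.b1[hg:].zero_()
        zeroed.W2[:, hg:].zero_()
    assert torch.allclose(out, zeroed(x, hg=hg)), "small output depends on late neurons"

    ffn.zero_grad()
    ffn(x, hg=hg).pow(2).mean().backward()
    assert ffn.W1.grad[hg:].abs().max() == 0, "gradient leaked into exclusive W1 rows"
    assert ffn.W2.grad[:, hg:].abs().max() == 0, "gradient leaked into exclusive W2 cols"
```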
2) Training objective & optimisation
A. Joint loss, weighting, and gradient interference
Q: State the training objective: weighted
average of submodel losses. Why does uniform weighting make sense
as a default, and when would you change it?
A: Objective:\[\mathcal{L}=\sum_{g\in\mathcal{G}}\alpha_g\,\mathcal{L}_g\]where
each (\mathcal{L}_g) is the next-token loss from submodel (g). Uniform weights
(\alpha_g=1/|\mathcal{G}|) are a reasonable default because they train all
granularities comparably and maintain consistency. Change weights
if deployment priorities skew (e.g., most traffic on small model),
or if smallest models collapse (increase their weight early), or
if you want best full-size performance (increase large model
weight).
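A hedged sketch of one joint training step under this objective, assuming a hypothetical model whose forward takes a width argument hg (as in the FFN sketch above) and returns logits directly:

```python
import torch
import torch.nn.functional as F


def joint_step(model, batch, widths, alphas, optimizer) -> float:
    """One MatFormer-style step: weighted sum of per-granularity next-token
    losses; shared parameters accumulate gradients from every width."""
    optimizer.zero_grad()
    total = 0.0
    for hg, a in zip(widths, alphas):
        logits = model(batch["input_ids"], hg=hg)            # (B, T, V) at this width
        loss = F.cross_entropy(
            logits[:, :-1].flatten(0, 1),                    # predict token t+1 from t
            batch["input_ids"][:, 1:].flatten(),
        )
        total = total + a * loss
    total.backward()
    optimizer.step()
    return float(total)
```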
Q: Show (mathematically or via intuition) how
shared parameters receive gradients from multiple
submodels. When can gradients conflict?
A: For shared parameter (\theta),\[\nabla_\theta \mathcal{L}=\sum_g
\alpha_g\nabla_\theta \mathcal{L}_g\]Conflict occurs when
gradient directions disagree (negative cosine similarity).
Q: Weight scheduling: early emphasise
small models vs late emphasise big model—what do you expect? How
would you decide?
A: One schedule: early increase weight on
small widths to force robust “core” features;
later shift weight toward largest to maximise
final quality. Expect better small-model performance and possibly
better generalisation; but too much early small weighting may cap
large-model peak. Decide by monitoring (i) per-granularity losses,
(ii) consistency metrics, and (iii) final downstream evaluation;
use ablations.
Q: What is the analogue of “sandwich
rule” from slimmable nets here? Does MatFormer implicitly
rely on it?
A: Sandwich rule typically trains smallest +
largest + a few random intermediate widths per step to cover width
space efficiently. MatFormer can use an analogous strategy:
compute losses for extremes and a sampled subset of widths rather
than all. The paper’s setup often uses a fixed discrete set; but
sampling is a plausible efficiency extension.
Q: If you observe the smallest submodel
collapsing (bad loss) while the largest is fine, what are your
first three interventions?
A: (1) Increase (\alpha_{small}) and/or apply a
curriculum that prioritises small early. (2) Improve neuron
ordering / enforce stronger nesting (e.g., explicit regulariser
that encourages early neurons to carry more signal). (3) Reduce
learning rate or use per-granularity auxiliary distillation (small
logits distil from large), stabilising gradients.
B. Compute & systems efficiency
Q: The paper notes joint training uses one
forward per submodel and benefits from shared computation. Where
exactly is compute shared, and where is it not?
A: Shared: embeddings, attention blocks, layer
norms, residual stream up to FFN computation, and (if implemented)
the first FFN matmul (W_1 x) can be shared by
computing full hidden once. Not shared: the final projection
(W_2[:, :h_g] z[:h_g]) per granularity, plus per-granularity loss
computations; backward graphs also differ per slice.
Q: What is the wall-clock
overhead vs training only the largest model? What
dominates it (extra matmuls, extra softmax, extra backward
graph)?
A: Overhead comes mainly from additional FFN
output matmuls and additional backward contributions, plus extra
loss computations. If implemented naively, repeated (W_1 x)
matmuls dominate overhead; if shared, overhead is closer to “extra
(W_2) matmuls + extra backward ops.” Softmax/loss overhead is
small relative to matmuls.
Q: If you were to implement a fused MatFormer
FFN kernel, what would the kernel signature look like? How would
you expose it cleanly to autograd?
A: Signature could accept
(x, W1, b1, W2, b2, widths) and return outputs for
selected widths or a packed tensor. Internally it computes full
hidden once and produces multiple outputs with minimal memory
traffic. Expose via a custom torch.autograd.Function
with saved tensors (or recomputation) and implement backward that
accumulates gradients properly for shared regions.
Q: How does gradient checkpointing interact
with multiple submodel forwards? What’s the right checkpoint
boundary?
A: Checkpointing reduces activation memory by
recomputing forward during backward. With multiple submodel
losses, checkpoint boundaries should wrap the entire block
whose activations are otherwise replicated, e.g., the
transformer layer including FFN computation. If computing full
hidden once, checkpoint at the layer boundary and recompute shared
hidden during backward for each loss, or store shared hidden once
if memory allows. You must ensure recomputation is
deterministic.
C. Inducing structure by finetuning
Q: The paper says you can induce MatFormer
structure via finetuning (not only pretraining). What assumptions
are required for this to work?
A: Assumptions: the pretrained model has
redundant capacity and can reorganise features under the new
constraint; the task/data used for finetuning is sufficient to
shape neuron importance; training is long enough and
well-conditioned. It works better if you start with a reasonable
neuron ordering and avoid catastrophic forgetting.
Q: If you finetune a pretrained dense model
into MatFormer, what do you do with “exclusive” neurons in larger
widths—initialise from existing weights or random? Why?
A: Prefer initialising from existing weights
(same indices) because it preserves learned features and speeds
convergence. If you must change dimensionality/order, use a
deterministic mapping plus small noise. Random init risks
destabilising large model and harming consistency; it may be
acceptable if exclusive regions are truly new capacity.
3) Mix’n’Match: exponential model extraction
A. Combinatorics & feasibility
Q: If you have (G) granularities and (L)
layers, how many Mix’n’Match models exist? When is that number
meaningful vs purely theoretical?
A: (G^L) possible per-layer assignments. It’s
meaningful when you can (i) evaluate or select among them
efficiently (DP/greedy/controller), and (ii) the extracted models
are “good” without per-model training. Otherwise it’s mostly
theoretical combinatorics.
Q: The paper claims Mix’n’Match models (never
explicitly optimised) still lie on the accuracy–compute curve
traced by trained granularities. Why might this generalise? When
might it fail?
A: It may generalise because each layer’s prefix
weights are trained under multiple widths, yielding modular
“cores” that compose. It can fail if there are strong inter-layer
dependencies (a narrow layer followed by wide layer causes
distribution shifts), or if training did not sufficiently cover
mixed configurations, leading to brittle compositions.
Q: Given a target compute budget (latency),
how do you choose a per-layer granularity assignment? Formulate it
as an optimisation problem (knapsack / DP / greedy / learned
router).
A: Let each layer-width choice (g) at layer (\ell) have cost
(c_{\ell,g}) and predicted utility (u_{\ell,g}) (e.g., marginal loss
reduction). Optimise:\[
\max \sum_{\ell} u_{\ell,g_\ell} \quad \text{s.t.}\quad
\sum_{\ell} c_{\ell,g_\ell} \le C
\]This is a knapsack-like problem; DP works if costs are discretised;
greedy or a learned router are cheaper approximations (a DP sketch follows below).
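A sketch of the knapsack DP under the assumption of integer per-layer costs; the costs/utils tables are hypothetical inputs (e.g., profiled latencies and estimated loss reductions), not something the paper or repo provides:

```python
def best_width_map(costs, utils, budget):
    """costs[l][g], utils[l][g]: integer cost and utility of choosing width g at
    layer l. Returns (best_utility, per-layer choice) with total cost <= budget.
    Assumes at least one assignment fits within the budget."""
    L = len(costs)
    NEG = float("-inf")
    # dp[l][c] = best utility using the first l layers at exact total cost c.
    dp = [[NEG] * (budget + 1) for _ in range(L + 1)]
    choice = [[None] * (budget + 1) for _ in range(L)]
    dp[0][0] = 0.0
    for l in range(L):
        for c in range(budget + 1):
            if dp[l][c] == NEG:
                continue
            for g, (cg, ug) in enumerate(zip(costs[l], utils[l])):
                if c + cg <= budget and dp[l][c] + ug > dp[l + 1][c + cg]:
                    dp[l + 1][c + cg] = dp[l][c] + ug
                    choice[l][c + cg] = (g, c)     # remember width and previous cost
    # Pick the best reachable end state and backtrack the per-layer widths.
    end = max(range(budget + 1), key=lambda c: dp[L][c])
    assignment, c = [], end
    for l in reversed(range(L)):
        g, c = choice[l][c]
        assignment.append(g)
    return dp[L][end], list(reversed(assignment))
```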
B. Intermediate granularities
Q: The paper mentions intermediate granularities (e.g., using the first (h’)
neurons for an intermediate width (h’)). Why should intermediate prefixes be “good”?
What property of training makes that plausible?
A: Because training enforces a ranking: early
neurons must serve all submodels, so features become progressively
more specialised as index increases. If this monotonic
“importance” emerges, then any prefix tends to be functional. This
is plausible when the loss includes multiple widths and gradients
repeatedly pressure early neurons to be universally useful.
Q: If you implement intermediate widths, how
do you ensure shape alignment and avoid retracing/compiling too
many variants (TorchDynamo / XLA / TensorRT)?
A: Use a universal model that always materialises
max shapes but masks/slices within kernels; or bucket widths into
a small set and reuse compiled graphs; or use dynamic shape
compilation if supported. For TensorRT-like systems, export a few
discrete width profiles, not every possible (h’).
C. Practical extraction
Q: What is the cleanest API for extraction:
build a new model object with per-layer widths vs keep one
universal model and pass a “width map” at runtime. Which is better
for deployment? for training?
A: For training, universal model + runtime width
map is cleaner (single set of params, avoids duplication). For
deployment, exporting fixed extracted submodels can simplify
compilation and runtime (static shapes), but universal+map is
great for dynamic serving if the backend supports it. A hybrid is
common: universal checkpoint + a small set of exported
profiles.
Q: How would you checkpoint extracted
submodels: store full weights + metadata, or store only universal
+ recipe?
A: Prefer universal + recipe (width map +
versioning) to avoid duplication and keep consistency. For
production, you might also store “materialised” weights for
specific profiles to speed loading/compilation; but keep the
recipe path for reproducibility.
4) Deployment & “behavioral consistency”
Q: Define two measures: (i) percent matching
generated tokens for same prefix, and (ii) KL divergence of
smaller model outputs vs larger model outputs. What are the
gotchas (sampling temperature, top-p, EOS handling)?
A: (i) Token match rate: run greedy decoding or
fixed-seed sampling and compute fraction of tokens where outputs
match. Gotchas: sampling stochasticity, tie-breaking, EOS
differences; must fix decoding settings, temperature, top-p/top-k,
and prompt formatting.
(ii) KL: compute KL between probability distributions over vocab
at each step; gotchas: must compare aligned steps (same
context), avoid numeric underflow (use log-softmax), and handle
tokens masked by top-p sampling (evaluate on full logits if
possible).
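A small sketch of both metrics, assuming logits are aligned step by step (e.g., teacher forcing on the same context) and greedy tokens were produced under identical decoding settings:

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def consistency_metrics(logits_small, logits_large, greedy_small, greedy_large):
    """Per-token KL(p_small || p_large) over the full vocab, computed in log
    space to avoid underflow, plus exact greedy-token match rate."""
    logp_s = F.log_softmax(logits_small, dim=-1)
    logp_l = F.log_softmax(logits_large, dim=-1)
    kl = (logp_s.exp() * (logp_s - logp_l)).sum(-1)          # (batch, seq)
    match = (greedy_small == greedy_large).float().mean()
    return kl.mean().item(), match.item()
```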
Q: Why does high consistency matter for
speculative decoding / model cascades / cross-platform serving
drift?
A: Speculative decoding requires the verifier to
accept draft tokens; acceptance depends directly on distributional
closeness. Cascades/early-exit require switching models without
changing outputs dramatically. Cross-platform drift: if mobile
uses small model and server uses large, you want user experience
stable and debuggable.
Q: If consistency is high in KL but low in
exact token matches, how do you interpret that?
A: Small differences in logits can flip the
argmax in competitive regions; KL can be low while argmax differs.
This suggests outputs are similar “in distribution” but decoding
is sensitive. Mitigations: use temperature smoothing, calibrate
small model, or focus on acceptance rates under the chosen
decoding scheme.
Q: For dynamic workloads, the paper suggests
token-based routing / elastic deployment. What’s the control
signal? How do you prevent oscillations (thrashing) across
widths?
A: Control signals: measured latency/queue depth,
per-token uncertainty (entropy), attention/FFN compute budget, or
task-specific difficulty signals. Prevent oscillations via
hysteresis (two thresholds), smoothing (EMA), minimum dwell time
per width, or a controller with a cost on switching.
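A toy hysteresis controller illustrating the two-threshold + dwell-time idea; thresholds, dwell time, and the latency signal are placeholders to be tuned for a real serving stack:

```python
class WidthController:
    """Two-threshold (hysteresis) controller with a minimum dwell time:
    shrink the width when latency exceeds the SLO, widen when it is comfortably low."""

    def __init__(self, widths, low_ms: float, high_ms: float, min_dwell: int = 50):
        self.widths = sorted(widths)
        self.low, self.high, self.min_dwell = low_ms, high_ms, min_dwell
        self.idx = len(self.widths) - 1          # start at full width
        self.steps_since_switch = 0

    def update(self, observed_latency_ms: float) -> int:
        self.steps_since_switch += 1
        if self.steps_since_switch >= self.min_dwell:
            if observed_latency_ms > self.high and self.idx > 0:
                self.idx -= 1                    # shrink width
                self.steps_since_switch = 0
            elif observed_latency_ms < self.low and self.idx < len(self.widths) - 1:
                self.idx += 1                    # grow width
                self.steps_since_switch = 0
        return self.widths[self.idx]
```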
5) Speculative decoding with MatLM
Q: Walk through speculative decoding precisely
(draft generation, parallel verification, rollback). Where does
draft–verifier mismatch cost time?
A: Steps: (1) Drafter generates a block of (k)
tokens. (2) Verifier runs on prompt + proposed tokens and checks
whether it would produce same tokens (or accepts via rejection
sampling criterion). (3) Accept prefix of tokens until first
mismatch; rollback remainder; continue. Cost arises when accept
rate is low: verifier work wasted on rejected tokens, plus
overhead of frequent rollbacks/synchronisation.
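A cache-free, greedy-acceptance sketch of this loop, assuming HuggingFace-style callables that return an object with .logits; it recomputes full forwards every round, so it shows the control flow (draft, verify, accept prefix, roll back) rather than a fast implementation:

```python
import torch


@torch.no_grad()
def speculative_decode(drafter, verifier, prompt_ids, k: int = 4, max_new: int = 64):
    """Drafter proposes k greedy tokens; one verifier forward checks them; accept
    the longest agreeing prefix plus the verifier's token at the first mismatch."""
    ids = prompt_ids.clone()                                   # (1, T)
    target_len = prompt_ids.shape[-1] + max_new
    while ids.shape[-1] < target_len:
        # 1) Draft k tokens autoregressively with the small model (greedy).
        draft = ids
        for _ in range(k):
            nxt = drafter(draft).logits[:, -1].argmax(-1, keepdim=True)
            draft = torch.cat([draft, nxt], dim=-1)
        # 2) Verify all proposed positions with a single large-model forward.
        v_logits = verifier(draft).logits
        v_pred = v_logits[:, ids.shape[-1] - 1 : -1].argmax(-1)   # verifier's greedy picks
        proposed = draft[:, ids.shape[-1]:]
        # 3) Accept the agreeing prefix, roll back the rest, take one verifier token.
        agree = (v_pred == proposed)[0].long()
        n_ok = int(agree.cumprod(0).sum())
        ids = torch.cat([ids, proposed[:, :n_ok], v_pred[:, n_ok:n_ok + 1]], dim=-1)
    return ids
```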
Q: The paper claims MatLM submodels can be
“more consistent” than independently trained baselines, improving
speculative speedups. What causes that, mechanistically?
A: Shared weights and joint training cause
smaller and larger models to share representations and logits more
closely. Independently trained models differ in hidden
representations and calibration, reducing accept rates.
MatFormer’s nesting effectively behaves like “built-in
distillation” across sizes.
Q: The paper also mentions sharing attention
cache across models from MatLM being feasible (but infeasible for
baselines). Explain exactly how that can be correct given
that FFNs differ and residual streams change. If it’s approximate,
what’s the approximation?
A: Exact cache sharing is correct
only if the inputs to attention (residual stream)
at each layer are identical between drafter and verifier. Since
FFNs differ, residuals generally differ, so exact cache sharing is
not generally valid. What can be valid: - share caches when the
same model verifies (trivial), or - share caches
up to layers where computation matches, or - approximate sharing
if differences are small and you accept slight drift (then it’s an
approximation, and must be validated).
Practically: the safest cache-sharing is within the same universal
model by selecting widths in a way that preserves attention
inputs—or by verifying using the same hidden states computed
once.
Q: If you used a Mix’n’Match model as a
drafter, how would you choose its per-layer widths to maximise
accept rate under a latency budget?
A: Choose widths that preserve “core” semantics:
keep early layers wider (stabilise representations) and shrink
later layers more (where features are more task-specific).
Optimise for accept rate by measuring KL/entropy alignment to
verifier per layer and running a small search over width schedules
under cost constraints.
Q: Design an ablation to separate gains from
(i) shared weights/consistency and (ii) any cache-sharing
trick.
A: Compare four conditions: 1) baseline
independent small+large (no cache sharing), 2) MatFormer
small+large (no cache sharing), 3) baseline with any approximate
cache reuse (if possible), 4) MatFormer with cache reuse.
Measure accept rate, verifier compute, and end-to-end latency. If
MatFormer beats baseline in (2) vs (1), it’s consistency; if (4)
adds extra gains over (2), it’s cache reuse.
6) MatLM experiments & scaling behaviour
Q: What are MatLM-{S,M,L,XL}? What changes
across them (only FFN ratio vs depth/width/heads)?
A: They are size tiers for the
largest model in the MatFormer family. Typically
changes include model width (d), depth (L), and sometimes FFN
ratio/head counts to reach parameter targets. The MatFormer
mechanism then yields nested smaller widths primarily by shrinking
FFNs. (Verify exact configs in the released checkpoints/config
files.)
Q: Evaluation suite: 26 English tasks split
into “GEN” vs “RANK”. Why split this way, and how does it change
sensitivity to small distribution shifts?
A: GEN tasks require generation quality under
decoding—more sensitive to calibration, sampling, and small logit
differences. RANK tasks (multiple-choice/ranking) are often
evaluated by scoring candidates, less sensitive to decoding
randomness and sometimes more stable across small distribution
shifts. The split helps separate “language modelling behaviour”
from “scoring behaviour.”
Q: What does it mean that Mix’n’Match yields
many models “on the optimal loss–compute curve at zero cost”? What
is “cost” here: train cost, tune cost, or evaluate cost?
A: “Zero cost” mainly means no additional
training: once you train the universal MatFormer, you can
extract many variants without retraining. There is still some
evaluation/selection cost if you want to choose
the best variant for a budget.
Q: Scaling law fit: paper fits a function of
non-embedding params (N) and tokens (D) and gets similar fitted
parameters for baseline vs MatFormer. What conclusion is
justified—and what is not justified?
A: Justified: MatFormer does not obviously break
scaling behaviour; its best-frontier points track similar trends.
Not justified: claiming universal equivalence of scaling constants
across all regimes or that MatFormer always matches baseline at
every compute point; scaling fits are empirical and limited by
measured ranges.
Q: The paper notes MatLM and baseline of same
size can have different FLOPs/step. Why, and how should scaling
analysis account for it?
A: Different FLOPs/step can arise from different
architectures (e.g., FFN ratio, head counts), or from extra
overhead of multi-width training. Scaling analysis should compare
models by actual compute (FLOPs) and tokens, not
just parameter count; ideally plot loss vs FLOPs or
wall-clock.
7) MatViT: vision & retrieval
Q: How is MatFormer applied to ViT encoders,
and what stays fixed?
A: Apply nesting primarily to the MLP (FFN)
within each transformer block, keeping patch embedding and
attention shapes fixed (or less frequently, nesting heads). The
key is maintaining consistent residual dimension so the encoder
remains compatible with downstream heads/retrieval pipelines.
Q: Training regimes mentioned: ViT-B/16 on
ImageNet-1K with AugReg; ViT-L/16 pretrain on ImageNet-21K then
finetune on ImageNet-1K (Scenic setup). Why does pretraining stage
matter for elasticity?
A: Pretraining learns general visual features and
stabilises representations. Elasticity benefits from a strong
“core” representation that small prefixes can reuse; pretraining
increases the chance that early neurons encode robust general
features, improving small-width performance and retrieval
stability.
Q: The paper claims smaller extracted encoders
preserve metric-space structure for adaptive
retrieval. What does “preserve structure” mean operationally?
Define a metric and an evaluation protocol.
A: Operationally: distances/similarities between
embeddings are approximately preserved. Metrics: Spearman
correlation of pairwise similarities, Procrustes alignment error,
or neighbourhood overlap (kNN consistency). Protocol: compute
corpus embeddings with XL, query embeddings with smaller widths,
measure recall@K / mAP against ground truth labels, and measure
embedding-space distortion measures.
Q: If your corpus embeddings are computed with
the universal model, what guarantees do you need for query
embeddings from smaller submodels to still retrieve well?
A: You need approximate
alignment: queries from smaller widths should map
near the same regions as XL queries. This is supported if early
features are shared and small widths approximate the XL embedding
function. In practice, verify by measuring neighbourhood overlap
and retrieval accuracy; if mismatch is high, you may need to
compute corpus embeddings for the same width or learn a small
alignment mapping.
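One way to operationalise the neighbourhood-overlap check, assuming precomputed embedding tensors (corpus embedded with XL, queries embedded at both widths):

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def knn_overlap(corpus_emb, queries_xl, queries_small, k: int = 10) -> float:
    """Fraction of top-k corpus neighbours shared between XL queries and
    small-width queries under cosine similarity; a proxy for 'preserved structure'."""
    corpus = F.normalize(corpus_emb, dim=-1)

    def topk_ids(q):
        return (F.normalize(q, dim=-1) @ corpus.T).topk(k, dim=-1).indices  # (Q, k)

    ids_xl, ids_s = topk_ids(queries_xl), topk_ids(queries_small)
    shared = [len(set(a.tolist()) & set(b.tolist())) for a, b in zip(ids_xl, ids_s)]
    return sum(shared) / (k * len(shared))
```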
Q: What breaks if you Mix’n’Match per-layer
widths arbitrarily in encoders used for retrieval? How would you
constrain width schedules to preserve isometry?
A: Arbitrary mixing can introduce layerwise
distribution shifts that distort embeddings, harming retrieval.
Constrain schedules to be smooth (avoid sudden
width jumps), prefer monotonic shrinkage (later layers narrower),
or restrict to a small menu of validated schedules; optionally
train with random width schedules to improve robustness.
8) MatFormer-OLMo repo: training/config/deployment
A. Training entrypoint & config system
Q: The repo’s training script is
scripts/train.py and supports distributed
training only (torchrun/Slurm). What design choices force
distributed-only? What would you change to enable single-GPU debug
mode?
A: Likely assumptions: initialisation of
distributed process group is mandatory, use of DDP/FSDP wrappers
everywhere, rank-aware dataloading, and config requiring
world-size. To enable single-GPU debug: allow
world_size=1, skip init_process_group if
not launched with torchrun, treat rank=0 as sole worker, and
disable sharded optimisers. Add a --single_process
flag and fall back to cuda:0.
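A sketch of the single-process fallback, relying only on torchrun's convention of setting RANK/WORLD_SIZE environment variables (this is not the repo's actual code):

```python
import os

import torch
import torch.distributed as dist


def maybe_init_distributed():
    """Fall back to single-process mode when not launched via torchrun.
    torchrun sets RANK/WORLD_SIZE; if they are absent, skip process-group setup."""
    if "RANK" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1:
        dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
        return dist.get_rank(), dist.get_world_size()
    return 0, 1  # rank 0, world size 1: plain single-device debug training


rank, world_size = maybe_init_distributed()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
```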
Q: Config: first argument is a config file;
overrides use dot-notation flags like
--optimizer.learning_rate=.... How would you
implement this override system robustly (types, nested structs,
lists)?
A: Use a structured config system (e.g.,
OmegaConf/Hydra) or implement: 1) parse overrides as
key=value strings, 2) split key by dots into path, 3)
type-cast value using schema (dataclasses/Pydantic) or YAML parser
for scalars/lists, 4) update nested dict safely, validating keys
and types, and reporting unknown keys. Ensure list indices are
supported (e.g., layers.3.hidden=...) only if
needed.
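A minimal dot-notation override over a nested dict config, as a sketch (not the repo's implementation; a structured system like OmegaConf or Hydra would add schema-based type checking on top of this):

```python
import ast


def apply_override(cfg: dict, flag: str) -> None:
    """Apply one '--optimizer.learning_rate=3e-4' style override in place.
    Values are parsed with ast.literal_eval (ints, floats, lists); anything
    unparsable stays a string. Unknown keys raise loudly instead of silently
    creating new config entries."""
    key, _, raw = flag.lstrip("-").partition("=")
    try:
        value = ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        value = raw
    *parents, leaf = key.split(".")
    node = cfg
    for p in parents:
        if p not in node:
            raise KeyError(f"unknown config section: {p!r}")
        node = node[p]
    if leaf not in node:
        raise KeyError(f"unknown config key: {key!r}")
    node[leaf] = value


cfg = {"optimizer": {"learning_rate": 1e-4}}
apply_override(cfg, "--optimizer.learning_rate=3e-4")
```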
Q: Repo flag: --matformer_factor.
Explain precisely: =1 baseline; =8
yields four granularities {h, h/2, h/4, h/8}. How is
“factor” mapped to widths?
A: Standard mapping: if
factor = 2^k, create widths (h/2^i) for (i=0..k). So
factor=8=2^3 → (i=0,1,2,3) → four widths.
Implementation likely computes k = int(log2(factor))
and generates h // (2**i) (with checks that divisions
are integral).
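The assumed mapping as a small sketch (verify against the repo's actual handling of --matformer_factor):

```python
import math


def widths_from_factor(h: int, factor: int) -> list:
    """matformer_factor = 2**k yields k+1 nested widths: h, h/2, ..., h/2**k."""
    k = int(math.log2(factor))
    assert 2 ** k == factor, "factor must be a power of two"
    assert all(h % (2 ** i) == 0 for i in range(k + 1)), "h must divide evenly"
    return [h // (2 ** i) for i in range(k + 1)]


print(widths_from_factor(8192, 8))   # [8192, 4096, 2048, 1024]
```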
B. Data/tokenisation
Q: The models are trained on The
Pile tokenised with EleutherAI/gpt-neox-20b.
What pitfalls arise from tokenizer choice when comparing baselines
vs MatFormer variants?
A: Tokeniser affects: sequence lengths (tokens
per character), effective dataset size in tokens, OOV handling,
and perplexity comparability. If baselines use different
tokenisers, comparisons are confounded. Also ensure special tokens
(BOS/EOS) and truncation/padding rules match.
Q: How would you validate that your dataloader
is deterministic across ranks and resumes (shuffling, epoch
boundaries, worker seeds)?
A: Use DistributedSampler with fixed
seed and epoch set each epoch; set worker init seeds
deterministically; log first N batch sample IDs per rank; on
resume, verify global step aligns with sampler state; test by
running two short training runs and comparing batch hashes.
C. Checkpoints & inference API
Q: The repo releases checkpoints and shows
loading via Olmo.from_checkpoint() and
Tokenizer.from_checkpoint(). What must a checkpoint
contain to support extraction / Mix’n’Match at inference
time?
A: It must include: model architecture config
(layer count, d_model, FFN hidden, matformer_factor, activation,
rotary settings, etc.), weights for universal model (full
(W_1,W_2)), and metadata enabling width selection (list of
supported widths and mapping). For Mix’n’Match, you also want a
schema for per-layer width maps.
Q: The repo provides generate()
with beam search. Where would you hook speculative decoding in
this API: inside generate(), or as a separate
decoding module?
A: Prefer a separate decoding module that wraps
the model’s forward() and KV-cache interface, because
speculative decoding is a different algorithm with different
control flow. You can keep generate() as a generic
entrypoint that chooses between decoding strategies, but
implementing speculative inside the core beam search code can
become messy.
D. Repro & cluster engineering
Q: The README mentions environment scripts and
Slurm setup. What are the three most common silent misconfigs
you’d expect (paths, HF cache, rank env vars), and how do you make
them fail loudly?
A: Common: (1) wrong dataset path or insufficient
permissions, (2) HF cache collisions / missing tokenizer files,
(3) incorrect MASTER_ADDR/PORT, RANK,
WORLD_SIZE, or GPU visibility. Fail loudly via
explicit checks at startup: verify files exist, can read/write
cache, validate distributed env vars, and do a short all-reduce
sanity test.
Q: What metrics would you log to ensure all
granularities are learning (per-granularity loss curves, gradient
norms on shared vs exclusive weights, consistency metrics)?
A: Log: loss/perplexity per granularity, learning
rate, gradient norm split into shared-prefix and exclusive-tail
regions, cosine similarity between gradients from different
granularities, KL/token-match consistency between widths, and
accept-rate proxies if doing speculative tests.
9) devvrit/matformer repo: minimal/alternate implementation prompts
Q: You see modified_llama.py +
train.py. Describe the minimal changes needed to turn
a standard LLaMA FFN into a MatFormer FFN without breaking HF
weight loading.
A: Keep parameter tensors at max
width so HF loading works. In forward, slice prefixes by
chosen width: W1[:hg,:], W2[:,:hg]. Keep
module names/parameter shapes consistent with HF state dict. Add a
config flag matformer_factor and a runtime width
selection mechanism. Ensure outputs remain shape
(batch, seq, d_model).
Q: How do you expose per-layer granularity
choices cleanly through the forward signature (e.g.,
width_map, layer_cfg) while keeping HF
generation utilities working?
A: Use
model.forward(..., width_map=None) with default
None meaning full width. HF generate()
passes extra kwargs via model_kwargs; ensure
prepare_inputs_for_generation forwards
width_map. Store width_map in the model
as a mutable attribute only if you must; explicit kwargs is
cleaner.
Q: If you trained only with granularities
{S,M,L,XL}, what tests would convince you Mix’n’Match models are
“real” and not evaluation noise?
A: (1) Evaluate many random Mix’n’Match schedules
and show smooth accuracy–compute curve with low variance. (2)
Compare to controls: random neuron permutations or random width
schedules without training. (3) Check consistency metrics and
downstream task scores; if performance correlates with predicted
compute and is stable across seeds, it’s real.
10) Failure modes, debugging, and rigorous ablations
Q: Gradient leakage: how do
you verify that the smallest model never updates “exclusive”
large-width weights (it shouldn’t touch them)?
A: After a backward pass using only small-model
loss, inspect .grad on (W_1[h_s:,:]) and
(W_2[:,h_s:]); they should be exactly zero (or numerically ~0).
Add unit tests and runtime assertions in debug mode.
Q: Dominance: if large model
loss dominates and small models stagnate, what measurable signals
show this early?
A: Small-model loss plateaus; gradient norms on
shared prefix reflect mostly large-model gradients; negative
gradient cosine similarity between small and large grows;
consistency worsens. Also watch effective learning rate for shared
weights (large gradients overwhelm small contributions).
Q: Consistency vs quality
trade: can you increase consistency at the cost of
accuracy? How (e.g., explicit KL regulariser to XL), and should
you?
A: Yes: add a KL term (\mathrm{KL}(p_g\,\|\,p_{XL})) to force
small outputs to match XL. This can reduce small model’s ability
to diverge when it would help its own accuracy. Whether you should
depends on product goals: for speculative decoding and cascades,
consistency can be worth a small accuracy hit.
Q: Layerwise width
allocation: is it better to shrink early layers or late
layers under fixed FLOPs? Give a hypothesis and a way to
test.
A: Hypothesis: keep early layers wider because
they build core representations; shrink later layers more because
they specialise. Test by fixing total cost and comparing
schedules: early-narrow vs late-narrow vs uniform; measure
downstream tasks and consistency.
Q: Distribution shift: do
Mix’n’Match models degrade more under out-of-domain evaluation
than explicitly trained widths? Why might that happen?
A: They can, because mixed schedules may yield
representation distributions not seen during training. Explicitly
trained widths are “on-manifold” for that width; Mix’n’Match can
be off-manifold. Mitigate by training with random schedules
(stochastic width per layer) or adding regularisers.
11) Surrounding questions (positioning & theory)
Comparisons & positioning
Q: Compare MatFormer to: OFA/slimmable nets,
pruning, distillation, quantisation, MoE routing, early-exit,
speculative decoding with separately trained models. Where does
MatFormer sit on the training-cost vs inference-flexibility
frontier?
A: MatFormer is closest to slimmable/OFA ideas
but adapted to Transformers and aimed at generative
consistency. Versus pruning: MatFormer bakes in structure
rather than pruning post hoc. Versus distillation: it’s like
multi-student training inside one model, but with shared weights.
Versus quantisation: orthogonal (can combine). Versus MoE: MoE
adds conditional compute but more parameters/complexity; MatFormer
changes width deterministically. Early-exit changes depth;
MatFormer changes width (can be combined). For speculative
decoding, MatFormer provides a natural family of consistent
drafter/verifier pairs at low additional training cost.
Q: MatFormer vs “train a dense model then do
structured pruning”: which gives better behavioral
consistency guarantees and why?
A: MatFormer generally offers stronger
consistency because smaller models are trained jointly and share
weights. Post-hoc pruning changes the function after training and
can induce distributional drift; you can fine-tune, but
consistency is not guaranteed unless explicitly optimised.
Theoretical/empirical grounding
Q: Why should the “prefix” of neurons be the
most important? Is there a plausible connection to feature
learning / lottery-ticket style arguments?
A: The prefix becomes important because it is the
intersection of all submodels: it receives
gradients from every width. This is akin to learning a “core”
subnetwork (lottery-ticket-like) that works under multiple
constraints. The ordering is emergent, not fundamental—training
pressure creates it.
Q: What does it mean for MatLM scaling trends
to “mimic” vanilla transformers? What evidence would falsify that
claim?
A: It means loss vs compute/params follows
similar functional forms and exponents. Falsification: MatFormer
deviates systematically—e.g., worse asymptotic scaling, different
optimal data/param ratio, or consistent gaps not explained by
compute differences.
Product/deployment design
Q: Suppose you’re serving 3 latency tiers
(mobile, edge GPU, datacenter). How do you ship MatFormer: one
universal checkpoint + recipes, or export three fixed
submodels?
A: If backend supports dynamic width selection
and you want consistency + flexible routing, ship universal +
recipes. If you want maximum runtime simplicity and static
compilation, export fixed profiles. Many production systems ship
both: universal for experimentation, fixed for stable
deployment.
Q: How would you build an adaptive controller
that chooses a Mix’n’Match width schedule per request/token while
enforcing SLOs?
A: Define a budget (C) per request/token. Use
signals like queue depth and token entropy; choose a schedule via
DP/greedy or a learned policy. Add hysteresis and penalties for
switching. Validate with offline simulation and online A/B tests,
measuring latency tail and quality.
12) Practical take-home exercises (with expected answers)
Q: Implement a MatFormer FFN module in PyTorch
(4 granularities), with unit tests for slicing, gradient tests,
and Mix’n’Match extraction API. What would “done” look like?
A: “Done” includes:
- MatFormerFFN(d, h, widths=[h, h/2, h/4, h/8])
implementing prefix slicing,
- tests: output equality with reference implementation; gradient
zero outside prefix for small loss; deterministic behaviour,
- extraction: extract(width_map) returns a wrapper
model that uses chosen widths; plus docs/examples.
Q: Profiling task: benchmark forward latency
vs width schedule; find the knee where attention dominates. What
should the report include?
A: Include: per-layer and total latency breakdown
(attention vs FFN vs other), latency vs sequence length and width,
GPU utilisation, and identify regime where shrinking FFN stops
helping. Provide recommended width tiers for typical sequence
lengths.
Q: Spec decoding prototype: drafter=small
submodel, verifier=XL; log accept rate, rollback rate, end-to-end
latency. What indicates success?
A: Success: accept rate high enough that verifier
amortises; end-to-end tokens/sec improves vs baseline decoding
with XL alone; quality remains comparable (task eval or human
checks). Report curves vs drafter size and token block length
(k).
Q: Retrieval task (encoder): compute corpus
embeddings with XL; query with smaller widths; quantify recall@K
vs compute. What indicates MatFormer helps?
A: If recall@K degrades gracefully with width and
small widths still retrieve well compared to independently trained
small encoders, then MatFormer preserves embedding structure. Also
show embedding distortion metrics and compute savings.