`train_gpt.py` implements a causal decoder-only transformer with several modern refinements layered on top of a standard GPT stack. The architecture is explicitly designed to fit comfortably within the 16 MB artifact budget.
## Default hyperparameters
All values are set in the `Hyperparameters` class and can be overridden via environment variables:
| Parameter | Default | Environment variable |
|---|---|---|
| Vocabulary size | 1024 | VOCAB_SIZE |
| Layers | 9 | NUM_LAYERS |
| Model dimension | 512 | MODEL_DIM |
| Attention heads | 8 | NUM_HEADS |
| KV heads (GQA) | 4 | NUM_KV_HEADS |
| MLP expansion | 2× | MLP_MULT |
| Tied embeddings | true | TIE_EMBEDDINGS |
| RoPE base | 10000.0 | ROPE_BASE |
| Logit softcap | 30.0 | LOGIT_SOFTCAP |
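A minimal sketch of the override mechanism, assuming plain `os.environ` lookups; the field names and parsing below are illustrative, and the real `Hyperparameters` class may differ:

```python
import os

# Hypothetical sketch: each field falls back to the table's default unless the
# matching environment variable is set.
class Hyperparameters:
    def __init__(self):
        self.vocab_size    = int(os.environ.get("VOCAB_SIZE", 1024))
        self.num_layers    = int(os.environ.get("NUM_LAYERS", 9))
        self.model_dim     = int(os.environ.get("MODEL_DIM", 512))
        self.num_heads     = int(os.environ.get("NUM_HEADS", 8))
        self.num_kv_heads  = int(os.environ.get("NUM_KV_HEADS", 4))
        self.mlp_mult      = int(os.environ.get("MLP_MULT", 2))
        self.rope_base     = float(os.environ.get("ROPE_BASE", 10000.0))
        self.logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))

os.environ["NUM_LAYERS"] = "12"   # override a single value
hp = Hyperparameters()            # hp.num_layers == 12, rest stay at defaults
```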
## U-Net skip connections
The 9 transformer blocks are split into an encoder half and a decoder half. The encoder layers store residual activations on a stack; the decoder layers pop them in reverse order and add them back with learned per-dimension weights. The model's `forward` method implements the skip-connection logic.
The U-Net pattern enables the decoder layers to directly access early-layer representations without having to reconstruct them through a long residual chain. This is especially useful in small models where each layer has limited capacity.
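The push-and-pop pattern can be sketched in plain Python, with floats standing in for residual tensors. The 4/1/4 encoder/middle/decoder split, the identity "block", and the constant `skip_weight` (in place of learned per-dimension weights) are illustrative assumptions, not the actual train_gpt.py code:

```python
def unet_forward(x, num_layers=9, skip_weight=0.5):
    half = num_layers // 2          # 4 encoder layers, 4 decoder layers
    skips = []                      # stack of stored encoder activations
    for _ in range(half):           # encoder: push residual, then run block
        skips.append(x)
        x = x + 1.0                 # stand-in for a transformer block
    x = x + 1.0                     # middle layer (odd layer count)
    for _ in range(half):           # decoder: pop in reverse (LIFO) order
        x = x + skip_weight * skips.pop()
        x = x + 1.0
    assert not skips                # every stored activation was consumed
    return x
```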
## Block internals
Each `Block` wraps its attention and MLP sub-layers with per-dimension scale factors and a residual mix gate.
The `x0` argument is the post-embedding state from the very first token embedding, passed unchanged to every block. The `resid_mix` gate blends the running residual `x` with this original embedding before each sub-layer.
## Key components
### CausalSelfAttention — GQA with RoPE and QK RMSNorm
Implements Grouped Query Attention (GQA): 8 query heads share 4 key/value heads. Queries and keys are RMS-normalised before computing attention, and a learnable per-head scalar `q_gain` is applied to queries after RoPE. `q_gain` is initialised to `QK_GAIN_INIT=1.5` and learned independently per head.
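The head grouping can be shown with a small sketch. The assumption that consecutive query heads share a KV head is the common GQA convention but is not confirmed by the source:

```python
# With 8 query heads and 4 KV heads, each group of 2 consecutive query heads
# attends against one shared key/value head.
NUM_HEADS, NUM_KV_HEADS = 8, 4

def kv_head_for(query_head: int) -> int:
    group_size = NUM_HEADS // NUM_KV_HEADS   # 2 query heads per KV head
    return query_head // group_size

mapping = [kv_head_for(h) for h in range(NUM_HEADS)]
```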
### MLP — relu² nonlinearity
The MLP uses a squared ReLU activation (relu²). The hidden dimension is `MLP_MULT × model_dim` (default: 2 × 512 = 1024). The projection weight is zero-initialised (`_zero_init = True`) so each block starts as an identity through its residual path.
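The activation itself is a one-liner; a scalar sketch:

```python
def relu2(x: float) -> float:
    # Squared ReLU: max(x, 0) ** 2 — zero for negatives, quadratic for positives
    r = max(x, 0.0)
    return r * r

hidden_dim = 2 * 512   # MLP_MULT × model_dim with the defaults above
```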
### RMSNorm — no learnable scale
`RMSNorm` wraps `F.rms_norm` without a learnable weight vector; per-dimension scaling is instead provided by the `attn_scale` and `mlp_scale` parameters in each `Block`.
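In pure Python a scale-free RMSNorm reduces to the following (the epsilon placement is an assumption; the real code operates on tensors via `F.rms_norm`):

```python
import math

def rms_norm(x, eps=1e-6):
    # Divide by the root-mean-square of the vector; no learnable gain.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

unit = rms_norm([3.0, 4.0])   # output has RMS ≈ 1
```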
### Rotary — RoPE with cached cos/sin tables
The `Rotary` module precomputes cosine/sine tables and caches them for the current sequence length and device. The cache is invalidated whenever `seq_len` or the device changes.
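The caching behaviour can be sketched as follows; the frequency formula is the standard RoPE `base^(-2i/d)`, the `(seq_len, device)` cache key is an assumption, and a string stands in for a real device:

```python
import math

class Rotary:
    def __init__(self, head_dim: int, base: float = 10000.0):
        self.head_dim, self.base = head_dim, base
        self._cache_key = None
        self.cos, self.sin = None, None

    def tables(self, seq_len: int, device: str = "cpu"):
        key = (seq_len, device)
        if key != self._cache_key:   # cache miss: rebuild and remember the key
            inv_freq = [self.base ** (-2 * i / self.head_dim)
                        for i in range(self.head_dim // 2)]
            self.cos = [[math.cos(p * f) for f in inv_freq] for p in range(seq_len)]
            self.sin = [[math.sin(p * f) for f in inv_freq] for p in range(seq_len)]
            self._cache_key = key
        return self.cos, self.sin

rot = Rotary(head_dim=8)
cos1, _ = rot.tables(16)
cos2, _ = rot.tables(16)   # cache hit: the same tables are returned
```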
### Block — resid_mix gate and per-dimension scale factors
Each block holds three control tensors kept in fp32 during training:
- `resid_mix` — shape `(2, dim)`, blends the current residual with the original embedding `x0`
- `attn_scale` — shape `(dim,)`, per-dimension scale applied to the attention output before adding to the residual
- `mlp_scale` — shape `(dim,)`, per-dimension scale applied to the MLP output before adding to the residual
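A sketch of how these control tensors could enter the forward pass, with Python lists in place of fp32 tensors. The exact blend formula (row 0 of `resid_mix` weighting `x`, row 1 weighting `x0`) is an assumption about train_gpt.py, not taken from it:

```python
def blend(x, x0, resid_mix):
    # resid_mix has shape (2, dim): row 0 weights x, row 1 weights x0
    return [m0 * a + m1 * b
            for m0, m1, a, b in zip(resid_mix[0], resid_mix[1], x, x0)]

def block_step(x, x0, resid_mix, out_scale, sublayer):
    h = sublayer(blend(x, x0, resid_mix))   # sub-layer sees the blended input
    # per-dimension scale (attn_scale or mlp_scale) on the residual add
    return [xi + s * hi for xi, s, hi in zip(x, out_scale, h)]
```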
## Logit softcap
Before computing the cross-entropy loss, logits are passed through a tanh softcap. With `LOGIT_SOFTCAP=30.0`, this smoothly clamps logits to the range (−30, +30), preventing extreme confidence on any single token and stabilising training.
## Tied embeddings
When `TIE_EMBEDDINGS=1` (the default), the LM head reuses the input embedding matrix directly.
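The essence of tying is that the head and the embedding are the same parameter object, so no separate head matrix is stored; a minimal sketch with nested lists standing in for weight tensors (class and field names are hypothetical):

```python
class TinyModel:
    def __init__(self, vocab_size=1024, dim=512, tie_embeddings=True):
        # (vocab, dim) input embedding table
        self.embed = [[0.0] * dim for _ in range(vocab_size)]
        # tied: the head IS the embedding table; untied: an independent copy
        self.head = self.embed if tie_embeddings else [row[:] for row in self.embed]
```

With tying, any gradient update to the embedding also updates the head, and the artifact stores the matrix once.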
## Optimizers
The model uses two optimizers in tandem.

### Muon
Applied to all 2D matrix parameters in transformer blocks (weight matrices that are not control tensors). Uses Newton–Schulz orthogonalisation to produce unit-spectral-norm updates. Learning rate: `MATRIX_LR=0.04`.

### Adam
Applied to token embeddings (`EMBED_LR` or `TIED_EMBED_LR`), the untied LM head (`HEAD_LR=0.008`), and all scalar/vector parameters including control tensors (`SCALAR_LR=0.04`).
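Newton–Schulz orthogonalisation can be illustrated with the classic cubic iteration `X ← 1.5·X − 0.5·X·Xᵀ·X`, which drives the singular values of a (normalised) matrix toward 1. This is a pure-Python sketch of the idea only: Muon in practice uses a tuned quintic polynomial run in low precision, not this exact iteration:

```python
import math

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def newton_schulz(g, steps=20):
    # Normalise by the Frobenius norm so the spectral norm is < sqrt(3),
    # the convergence region of the cubic iteration.
    fro = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / fro for v in row] for row in g]
    for _ in range(steps):
        xxt_x = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * x[i][j] - 0.5 * xxt_x[i][j]
              for j in range(len(x[0]))] for i in range(len(x))]
    return x

q = newton_schulz([[3.0, 1.0], [0.0, 2.0]])   # q is (approximately) orthogonal
```

The result replaces the raw gradient, so every Muon update to a weight matrix has roughly unit spectral norm regardless of the gradient's scale.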