The baseline model in train_gpt.py is a causal transformer with several modern refinements layered on top of a standard decoder-only GPT stack. The architecture is explicitly designed to fit comfortably within the 16 MB artifact budget.

Default hyperparameters

All values are set in the Hyperparameters class and can be overridden via environment variables:
| Parameter | Default | Environment variable |
| --- | --- | --- |
| Vocabulary size | 1024 | VOCAB_SIZE |
| Layers | 9 | NUM_LAYERS |
| Model dimension | 512 | MODEL_DIM |
| Attention heads | 8 | NUM_HEADS |
| KV heads (GQA) | 4 | NUM_KV_HEADS |
| MLP expansion | 2 | MLP_MULT |
| Tied embeddings | true | TIE_EMBEDDINGS |
| RoPE base | 10000.0 | ROPE_BASE |
| Logit softcap | 30.0 | LOGIT_SOFTCAP |
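The override pattern can be sketched as follows; the exact field names and parsing in the real Hyperparameters class are assumptions based on the table above:

```python
import os
from dataclasses import dataclass

# Hedged sketch: each hyperparameter reads its environment variable,
# falling back to the table's default when the variable is unset.
@dataclass
class Hyperparameters:
    vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024))
    num_layers: int = int(os.environ.get("NUM_LAYERS", 9))
    model_dim: int = int(os.environ.get("MODEL_DIM", 512))
    num_heads: int = int(os.environ.get("NUM_HEADS", 8))
    num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4))
    mlp_mult: int = int(os.environ.get("MLP_MULT", 2))
    tie_embeddings: bool = os.environ.get("TIE_EMBEDDINGS", "1") == "1"
    rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0))
    logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0))

hp = Hyperparameters()
```

Running `NUM_LAYERS=12 python train_gpt.py` would then train a 12-layer model with everything else at its default.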

U-Net skip connections

The 9 transformer blocks are split into an encoder half and a decoder half. The encoder layers store residual activations on a stack; the decoder layers pop them in reverse order and add them back with learned per-dimension weights.
# From GPT.__init__
self.num_encoder_layers = num_layers // 2      # 4 for default 9-layer config
self.num_decoder_layers = num_layers - self.num_encoder_layers  # 5
self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)  # 4
self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
The forward method implements the skip-connection logic:
def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
    x = self.tok_emb(input_ids)
    x = F.rms_norm(x, (x.size(-1),))
    x0 = x
    skips: list[Tensor] = []

    # First half stores skips; second half reuses them in reverse order.
    for i in range(self.num_encoder_layers):
        x = self.blocks[i](x, x0)
        skips.append(x)
    for i in range(self.num_decoder_layers):
        if skips:
            x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
        x = self.blocks[self.num_encoder_layers + i](x, x0)

    x = self.final_norm(x).reshape(-1, x.size(-1))
    targets = target_ids.reshape(-1)
    if self.tie_embeddings:
        logits_proj = F.linear(x, self.tok_emb.weight)
    else:
        logits_proj = self.lm_head(x)
    logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
    return F.cross_entropy(logits.float(), targets, reduction="mean")
The U-Net pattern enables the decoder layers to directly access early-layer representations without having to reconstruct them through a long residual chain. This is especially useful in small models where each layer has limited capacity.
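The LIFO pairing that the stack produces can be traced with a pure-Python sketch: the first decoder layer reuses the *last* encoder output, and with 4 skips for 5 decoder layers the final decoder layer receives none.

```python
# Trace the skip pairing for the default 9-layer stack (4 encoder, 5 decoder).
num_layers = 9
num_encoder = num_layers // 2            # 4
num_decoder = num_layers - num_encoder   # 5

skips = [f"enc{i}" for i in range(num_encoder)]   # encoder pushes in order
pairs = []
for i in range(num_decoder):
    # decoder pops in reverse (LIFO) order; the guard covers the odd layer out
    pairs.append((f"dec{i}", skips.pop() if skips else None))

print(pairs)
```

This mirrors the `skips.append` / `skips.pop()` logic in GPT.forward exactly, just with labels instead of activations.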

Block internals

Each Block wraps attention and MLP sub-layers with per-dimension scale factors and a residual mix gate:
def forward(self, x: Tensor, x0: Tensor) -> Tensor:
    mix = self.resid_mix.to(dtype=x.dtype)
    x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
    attn_out = self.attn(self.attn_norm(x))
    x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
    x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
    return x
The x0 argument is the RMS-normalised token-embedding output computed once at the top of forward, passed unchanged to every block. The resid_mix gate blends the running residual x with this original embedding before each sub-layer.
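A small numeric sketch of the blend; the (1, 0) initialisation used here is an assumption for illustration, not necessarily what train_gpt.py uses:

```python
import torch

# resid_mix has shape (2, dim): row 0 weights the running residual x,
# row 1 weights the original embedding x0.
dim = 4
resid_mix = torch.tensor([[1.0] * dim, [0.0] * dim])  # assumed (1, 0) init

x = torch.randn(1, 3, dim)    # running residual
x0 = torch.randn(1, 3, dim)   # original (post-embedding) state
mixed = resid_mix[0][None, None, :] * x + resid_mix[1][None, None, :] * x0

# With a (1, 0) init the gate is a no-op; training can then learn to
# re-inject x0 per dimension.
assert torch.equal(mixed, x)
```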

Key components

The attention module implements Grouped Query Attention (GQA): 8 query heads share 4 key/value heads. Queries and keys are RMS-normalised before computing attention, and a learnable per-head scalar q_gain is applied to queries after RoPE:
q = F.rms_norm(q, (q.size(-1),))
k = F.rms_norm(k, (k.size(-1),))
cos, sin = self.rotary(seqlen, x.device, q.dtype)
q = apply_rotary_emb(q, cos, sin)
k = apply_rotary_emb(k, cos, sin)
q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
y = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    is_causal=True,
    enable_gqa=(self.num_kv_heads != self.num_heads),
)
q_gain is initialised to QK_GAIN_INIT=1.5 and learned independently per head.
The MLP uses a squared ReLU activation (relu²):
def forward(self, x: Tensor) -> Tensor:
    x = torch.relu(self.fc(x))
    return self.proj(x.square())
Hidden dimension is MLP_MULT × model_dim (default: 2 × 512 = 1024). The projection weight is zero-initialised (_zero_init = True) so each block starts as an identity through its residual path.
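The identity-at-init property is easy to verify with a standalone sketch (layer names here are assumptions):

```python
import torch
import torch.nn as nn

# Sketch of the relu^2 MLP with a zero-initialised projection.
dim, mult = 512, 2
fc = nn.Linear(dim, mult * dim, bias=False)
proj = nn.Linear(mult * dim, dim, bias=False)
nn.init.zeros_(proj.weight)  # the effect of _zero_init = True

x = torch.randn(2, 8, dim)
out = proj(torch.relu(fc(x)).square())

# Zeroed projection => the MLP contributes nothing at init, so the
# surrounding residual connection starts as an exact identity.
assert torch.equal(out, torch.zeros_like(out))
assert torch.equal(x + out, x)
```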
RMSNorm wraps F.rms_norm without a learnable weight vector:
def forward(self, x: Tensor) -> Tensor:
    return F.rms_norm(x, (x.size(-1),), eps=self.eps)
Scale factors are instead handled by the per-dimension attn_scale and mlp_scale parameters in each Block.
The Rotary module precomputes cosine/sine tables and caches them for the current sequence length and device:
def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype):
    # Rebuild the tables only when seq_len or device changes
    # (the cache-key attribute names here are assumed).
    if seq_len != self._seq_len_cached or device != self._device_cached:
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq.to(device))
        self._cos_cached = freqs.cos()[None, None, :, :]
        self._sin_cached = freqs.sin()[None, None, :, :]
        self._seq_len_cached, self._device_cached = seq_len, device
    return self._cos_cached.to(dtype), self._sin_cached.to(dtype)
The cached tables are reused until seq_len or device changes, and the returned cos/sin are cast to the query dtype to match the unpacking at the attention call site.
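The apply_rotary_emb helper referenced by the attention code is not shown above; a hedged sketch under a half-split channel layout (the real implementation may use an interleaved layout instead) looks like this:

```python
import torch

# Hedged sketch of apply_rotary_emb: rotate each (x1, x2) channel pair by a
# position-dependent angle. Assumes the half-split layout, where the head
# dimension is divided into two contiguous halves.
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, heads, seq, head_dim); cos/sin: (1, 1, seq, head_dim // 2)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin,
                      x1 * sin + x2 * cos], dim=-1)

x = torch.randn(2, 8, 16, 64)
angles = torch.outer(torch.arange(16.0), torch.ones(32))
cos, sin = angles.cos()[None, None], angles.sin()[None, None]
out = apply_rotary_emb(x, cos, sin)

assert out.shape == x.shape
# A rotation preserves the norm of each channel pair, hence of the vector.
assert torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-4)
```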
Each block holds three control tensors kept in fp32 during training:
  • resid_mix — shape (2, dim), blends current residual with original embedding x0
  • attn_scale — shape (dim,), per-dimension scale applied to attention output before adding to residual
  • mlp_scale — shape (dim,), per-dimension scale applied to MLP output before adding to residual
These are classified as “control tensors” and are kept in fp32 during quantization (see Model Quantization).

Logit softcap

Before computing cross-entropy loss, logits are passed through a tanh softcap:
logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
With LOGIT_SOFTCAP=30.0, this smoothly clamps logits to the range (−30, +30), preventing extreme confidence on any single token and stabilising training.
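The effect is easy to see numerically: the cap is near-identity for small logits, strictly monotonic, and bounded, unlike a hard clamp whose gradient vanishes outside the clip range.

```python
import torch

softcap = 30.0
raw = torch.tensor([-100.0, -30.0, 0.0, 30.0, 100.0])
capped = softcap * torch.tanh(raw / softcap)

assert capped.abs().max() < softcap          # strictly inside (-30, 30)
assert torch.all(capped[1:] > capped[:-1])   # ordering of logits preserved
# Near-identity for small logits: tanh(x) ~ x when |x| << 1.
small = torch.tensor(0.1)
assert torch.allclose(softcap * torch.tanh(small / softcap), small, atol=1e-3)
```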

Tied embeddings

When TIE_EMBEDDINGS=1 (the default), the LM head reuses the input embedding matrix directly:
if self.tie_embeddings:
    logits_proj = F.linear(x, self.tok_emb.weight)
This halves the parameter cost of the vocabulary table, which matters under the 16 MB budget: a single 1024 × 512 matrix costs 1 MB in bf16 (1024 × 512 × 2 bytes), so separate input and output matrices would cost 2 MB where the tied model pays 1 MB.
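The back-of-envelope arithmetic:

```python
# Embedding cost in the 16 MB artifact budget.
vocab_size, model_dim, bf16_bytes = 1024, 512, 2

one_matrix = vocab_size * model_dim * bf16_bytes
assert one_matrix == 1 * 1024 * 1024   # exactly 1 MiB per matrix

untied = 2 * one_matrix                # separate embedding + LM head
tied = one_matrix                      # one shared matrix
assert untied - tied == 1024 * 1024    # tying saves 1 MiB
```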

Optimizers

The model uses two optimizers in tandem:

Muon

Applied to all 2D matrix parameters in transformer blocks (weight matrices that are not control tensors). Uses Newton-Schulz orthogonalisation to produce unit-spectral-norm updates. Learning rate: MATRIX_LR=0.04.
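The orthogonalisation step can be sketched with the quintic Newton-Schulz iteration popularised by the public Muon implementation; the coefficients and step count in train_gpt.py may differ:

```python
import torch

# Hedged sketch of Muon's Newton-Schulz orthogonalisation: iteratively push
# all singular values of the update matrix toward 1 without an explicit SVD.
def newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    a, b, c = 3.4445, -4.7750, 2.0315   # coefficients from the public Muon code
    X = G / (G.norm() + eps)            # scale so singular values are <= 1
    transposed = G.size(0) > G.size(1)
    if transposed:                      # iterate on the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X   # odd quintic in the singular values
    return X.T if transposed else X

G = torch.diag(torch.tensor([3.0, 2.0, 1.0]))  # condition number 3 at input
O = newton_schulz(G)
s = torch.linalg.svdvals(O)
assert s.min() > 0.5 and s.max() < 1.3  # all singular values pushed toward 1
```

The iteration does not converge exactly to 1; the coefficients trade precision for speed, landing singular values in a band around 1, which is sufficient for a well-scaled update.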

Adam

Applied to token embeddings (EMBED_LR or TIED_EMBED_LR), the untied LM head (HEAD_LR=0.008), and all scalar/vector parameters including control tensors (SCALAR_LR=0.04).
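The two-group split can be sketched on a toy model; SGD stands in for Muon here, and the embedding learning rate is a placeholder since EMBED_LR's default is not quoted above:

```python
import torch
import torch.nn as nn

# Toy model: one embedding table (Adam) and one weight matrix (Muon-style).
model = nn.Sequential(nn.Embedding(1024, 512), nn.Linear(512, 512, bias=False))

matrix_params = [p for m in model.modules() if isinstance(m, nn.Linear)
                 for p in m.parameters()]
adam_params = [p for m in model.modules() if isinstance(m, nn.Embedding)
               for p in m.parameters()]

# SGD stands in for Muon; MATRIX_LR=0.04 matches the default quoted above.
muon_stand_in = torch.optim.SGD(matrix_params, lr=0.04)
adam = torch.optim.Adam(adam_params, lr=0.008)  # placeholder embedding LR

assert sum(p.numel() for p in matrix_params) == 512 * 512
assert sum(p.numel() for p in adam_params) == 1024 * 512
```

In the real script the split also routes scalar/vector parameters (the control tensors, q_gain, skip_weights) into the Adam group at SCALAR_LR=0.04.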