Overview

The nanochat GPT model is a simplified, modern transformer architecture with several notable features:
  • Rotary positional embeddings (RoPE) instead of learned positional embeddings
  • Query-Key normalization for training stability
  • Untied weights for token embedding and language model head
  • ReLU² activation in MLP layers
  • Normalization after token embedding
  • No learnable parameters in RMSNorm
  • No bias in linear layers
  • Group-Query Attention (GQA) for efficient inference
  • Sliding window attention pattern support
  • Flash Attention 3 integration with automatic fallback

GPTConfig

The model is configured using a dataclass with the following fields:
@dataclass
class GPTConfig:
    sequence_len: int = 2048      # Maximum sequence length
    vocab_size: int = 32768        # Vocabulary size
    n_layer: int = 12              # Number of transformer layers
    n_head: int = 6                # Number of query heads
    n_kv_head: int = 6             # Number of key/value heads (GQA)
    n_embd: int = 768              # Embedding dimension
    window_pattern: str = "SSSL"  # Sliding window attention pattern
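As a sketch, enabling GQA is just a matter of setting `n_kv_head` below `n_head` (the dataclass is repeated here so the snippet is self-contained; values beyond the defaults above are illustrative):

```python
from dataclasses import dataclass

# Minimal copy of the config above, used to illustrate the GQA constraint.
@dataclass
class GPTConfig:
    sequence_len: int = 2048
    vocab_size: int = 32768
    n_layer: int = 12
    n_head: int = 6
    n_kv_head: int = 6
    n_embd: int = 768
    window_pattern: str = "SSSL"

cfg = GPTConfig(n_kv_head=2)          # 3 query heads share each KV head
assert cfg.n_head % cfg.n_kv_head == 0
head_dim = cfg.n_embd // cfg.n_head   # 128
```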

Window Pattern

The window_pattern string controls sliding window attention across layers:
  • L (Long): Full context attention (sequence_len tokens)
  • S (Short): Half context attention (sequence_len // 2 tokens)
The pattern is tiled across layers. For example:
  • "L" = all layers use full context
  • "SL" = alternating short/long across layers
  • "SSSL" = three short, one long (repeating)
The final layer always gets full context regardless of pattern.
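The tiling rule above can be sketched as a small helper (hypothetical, not the repo's code):

```python
def layer_window_sizes(window_pattern: str, n_layer: int, sequence_len: int):
    """Tile the window pattern across layers: 'L' = full context,
    'S' = half context. The final layer is forced to 'L'."""
    sizes = []
    for i in range(n_layer):
        char = window_pattern[i % len(window_pattern)]
        if i == n_layer - 1:
            char = "L"  # final layer always gets full context
        sizes.append(sequence_len if char == "L" else sequence_len // 2)
    return sizes

# "SSSL" over 12 layers: three short (1024), one long (2048), repeating
sizes = layer_window_sizes("SSSL", 12, 2048)
```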

Architecture Components

Rotary Embeddings (RoPE)

The model uses rotary positional embeddings instead of learned positional embeddings:
def apply_rotary_emb(x, cos, sin):
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3)
Key properties:
  • Base frequency: 10000 (standard GPT-style)
  • Applied to both queries and keys
  • Provides relative positional information
  • Pre-computed up to 10x the configured sequence length
  • Stored in bfloat16 format
Reference: gpt.py:243-258
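The precomputed tables described above might be built like this (a sketch under the stated properties — base 10000, bfloat16 storage, 10x the configured sequence length — not the exact gpt.py code):

```python
import torch

def precompute_rotary(head_dim: int, max_seq_len: int, base: float = 10000.0):
    # Standard GPT-style RoPE frequencies: one per pair of channels.
    channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32)
    inv_freq = 1.0 / (base ** (channel_range / head_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)       # (max_seq_len, head_dim // 2)
    cos, sin = freqs.cos(), freqs.sin()
    return cos.bfloat16(), sin.bfloat16()  # stored in bfloat16

# Pre-computed up to 10x the configured sequence length, per the notes above.
cos, sin = precompute_rotary(head_dim=128, max_seq_len=10 * 2048)
```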

QK Normalization

Both queries and keys are normalized after RoPE application:
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k)  # QK norm
This uses functional RMSNorm with no learnable parameters:
def norm(x):
    return F.rms_norm(x, (x.size(-1),))
Reference: gpt.py:42-44, gpt.py:94

Group-Query Attention (GQA)

GQA reduces memory usage during inference by sharing key/value heads across multiple query heads:
  • n_head: Number of query heads (e.g., 6)
  • n_kv_head: Number of key/value heads (e.g., 2 or 3)
  • Constraint: n_head % n_kv_head == 0
When n_kv_head == n_head, this is standard multi-head attention. When n_kv_head < n_head, keys and values are shared across groups of query heads. Reference: gpt.py:63-68
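The sharing scheme can be illustrated by expanding the KV heads to match the query heads before attention (an illustrative sketch, not the library's code, which may instead pass grouped heads directly to the kernel):

```python
import torch
import torch.nn.functional as F

B, T, n_head, n_kv_head, head_dim = 2, 8, 6, 2, 16
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_kv_head, T, head_dim)
v = torch.randn(B, n_kv_head, T, head_dim)

# Each KV head serves n_head // n_kv_head query heads.
groups = n_head // n_kv_head            # here: 3 query heads per KV head
k = k.repeat_interleave(groups, dim=1)  # (B, n_head, T, head_dim)
v = v.repeat_interleave(groups, dim=1)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```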

MLP with ReLU²

The MLP uses a squared ReLU activation:
class MLP(nn.Module):
    def forward(self, x):
        x = self.c_fc(x)           # (n_embd, 4 * n_embd)
        x = F.relu(x).square()     # ReLU²
        x = self.c_proj(x)         # (4 * n_embd, n_embd)
        return x
This provides:
  • Stronger non-linearity than standard ReLU
  • Better gradient flow properties
  • No additional parameters
Reference: gpt.py:121-131

Sliding Window Attention

The model supports configurable sliding window attention to reduce memory and compute:
  • Window sizes computed per-layer from window_pattern
  • Format: (left, right) tuple
    • left: tokens before current position (-1 = unlimited)
    • right: tokens after current position (0 for causal)
  • Examples:
    • (-1, 0): full causal attention
    • (1024, 0): attend to last 1024 tokens only
Reference: gpt.py:260-287
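The `(left, right)` convention can be made concrete with a boolean mask (a hypothetical helper for illustration; the actual implementation passes window sizes to the attention kernel):

```python
import torch

def sliding_window_mask(T: int, left: int, right: int) -> torch.Tensor:
    """Boolean attention mask for a (left, right) window.
    left = -1 means unlimited look-back; right = 0 keeps attention causal."""
    i = torch.arange(T).unsqueeze(1)  # query positions
    j = torch.arange(T).unsqueeze(0)  # key positions
    mask = j <= i + right             # right bound (0 -> causal)
    if left >= 0:
        mask &= j >= i - left         # left bound
    return mask

full = sliding_window_mask(6, -1, 0)     # full causal attention
windowed = sliding_window_mask(6, 2, 0)  # attend to self + last 2 tokens
```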

Value Embeddings (ResFormer-style)

Alternating layers include value embeddings that are mixed into the attention values:
if ve is not None:
    ve = ve.view(B, T, self.n_kv_head, self.head_dim)
    gate = 2 * torch.sigmoid(self.ve_gate(x[..., :32]))
    v = v + gate.unsqueeze(-1) * ve
  • Applied to alternating layers (last layer always included)
  • Input-dependent gating per head
  • Gate range: (0, 2) via 2 * sigmoid
  • Uses first 32 channels of input for gate computation
Reference: gpt.py:86-89, gpt.py:174-177

Untied Weights

The token embedding (wte) and language model head (lm_head) use separate weight matrices:
self.transformer.wte = nn.Embedding(vocab_size, n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
This allows:
  • Different initialization strategies
  • Independent learning rates (unembedding_lr vs embedding_lr)
  • Better optimization dynamics
Reference: gpt.py:164-167

Per-Layer Scalars

The model includes learnable per-layer scaling parameters:
self.resid_lambdas = nn.Parameter(torch.ones(n_layer))   # initialized to 1.0
self.x0_lambdas = nn.Parameter(torch.zeros(n_layer))     # set to 0.1 during weight init

# In forward pass:
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
  • resid_lambdas: scales residual stream (init 1.0 = neutral)
  • x0_lambdas: blends initial embedding back in (init 0.1 = small influence)
  • Separate optimizer treatment for each
Reference: gpt.py:172-173, gpt.py:404

Weight Initialization

The model uses a custom initialization scheme optimized for stability:
| Parameter          | Distribution | Std Dev / Value |
|--------------------|--------------|-----------------|
| wte (embedding)    | Normal       | 1.0             |
| lm_head            | Normal       | 0.001           |
| attn.c_q, c_k, c_v | Uniform      | 1/√n_embd       |
| attn.c_proj        | Zeros        | -               |
| mlp.c_fc           | Uniform      | 1/√n_embd       |
| mlp.c_proj         | Zeros        | -               |
| value_embeds       | Uniform      | 1/√n_embd       |
| ve_gate            | Zeros        | -               |
| resid_lambdas      | -            | 1.0             |
| x0_lambdas         | -            | 0.1             |
Projection layers (c_proj) are initialized to zero for a clean residual path at initialization. Reference: gpt.py:188-235
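Two rows of the table can be sketched as follows (layer names follow the table; this is an illustration of the scheme, not the repo's init code):

```python
import math
import torch
import torch.nn as nn

n_embd = 768
bound = 1.0 / math.sqrt(n_embd)  # uniform bound of 1/sqrt(n_embd)

c_q = nn.Linear(n_embd, n_embd, bias=False)
c_proj = nn.Linear(n_embd, n_embd, bias=False)
nn.init.uniform_(c_q.weight, -bound, bound)
nn.init.zeros_(c_proj.weight)  # clean residual path at initialization
```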

Vocab Padding

The vocabulary size is padded to a multiple of 64 for efficiency:
padded_vocab_size = ((vocab_size + 63) // 64) * 64
This improves:
  • Tensor core utilization
  • Distributed training performance
  • Memory access patterns
Logits are cropped back to the true vocabulary size before loss computation. Reference: gpt.py:159-162, gpt.py:412
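The rounding formula above, applied to a couple of example sizes (the second input is illustrative, not nanochat's vocabulary):

```python
def pad_vocab(vocab_size: int, multiple: int = 64) -> int:
    # Round up to the next multiple of 64.
    return ((vocab_size + multiple - 1) // multiple) * multiple

a = pad_vocab(32768)  # already a multiple of 64 -> unchanged
b = pad_vocab(50257)  # GPT-2's vocab size, as an illustrative input
```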

Logit Softcapping

Logits are smoothly capped to prevent extreme values:
softcap = 15
logits = softcap * torch.tanh(logits / softcap)
This provides:
  • More stable training
  • Better numerical properties
  • Bounded logit range: [-15, 15]
Reference: gpt.py:410-414
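A quick numerical check of the cap: large magnitudes saturate toward ±15 while values near zero pass through almost unchanged.

```python
import torch

softcap = 15.0
logits = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = softcap * torch.tanh(logits / softcap)
# capped stays strictly inside (-15, 15); capped[3] is ~0.9985, close to 1.0
```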

Memory Optimizations

BFloat16 Embeddings

Embedding layers are cast to bfloat16 after initialization:
self.transformer.wte.to(dtype=torch.bfloat16)
for ve in self.value_embeds.values():
    ve.to(dtype=torch.bfloat16)
This reduces memory usage without impacting optimizer behavior. Reference: gpt.py:238-241

FLOPs Estimation

The model includes a method to estimate training FLOPs:
flops_per_token = model.estimate_flops()
This accounts for:
  • 6 FLOPs per weight parameter (2 forward + 4 backward)
  • Attention QK^T matmul: 12 * h * q * effective_seq_len per layer
  • Sliding window adjustments (effective_seq_len varies per layer)
  • Excludes embeddings and non-matmul operations
Reference: gpt.py:292-317
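The accounting above can be sketched as a back-of-envelope helper (hypothetical, not the repo's `estimate_flops()`; the parameter count per layer assumes `n_kv_head == n_head` and ignores smaller matrices):

```python
def estimate_flops_per_token(n_layer, n_embd, n_head, head_dim, seq_lens):
    """6 FLOPs per weight parameter, plus 12 * n_head * head_dim *
    effective_seq_len for attention per layer. Embeddings and
    non-matmul operations are excluded."""
    # Attention QKV + output projection (~4 * n_embd^2) plus 4x MLP (~8 * n_embd^2).
    params_per_layer = 4 * n_embd * n_embd + 8 * n_embd * n_embd
    weight_flops = 6 * n_layer * params_per_layer
    attn_flops = sum(12 * n_head * head_dim * t for t in seq_lens)
    return weight_flops + attn_flops

# "SSSL" over 12 layers: short layers attend to half the context.
windowed = estimate_flops_per_token(12, 768, 6, 128, [1024, 1024, 1024, 2048] * 3)
full = estimate_flops_per_token(12, 768, 6, 128, [2048] * 12)
```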

Optimizer

MuonAdamW optimizer details

Flash Attention

Flash Attention 3 integration
