Overview
The nanochat GPT model is a simplified, modern transformer architecture with several notable features:
- Rotary positional embeddings (RoPE) instead of learned positional embeddings
- Query-Key normalization for training stability
- Untied weights for token embedding and language model head
- ReLU² activation in MLP layers
- Normalization after token embedding
- No learnable parameters in RMSNorm
- No bias in linear layers
- Group-Query Attention (GQA) for efficient inference
- Sliding window attention pattern support
- Flash Attention 3 integration with automatic fallback
GPTConfig
The model is configured using a GPTConfig dataclass whose fields include the sequence length, vocabulary size, layer and head counts, embedding width, and the sliding-window pattern string.
Window Pattern
The window_pattern string controls sliding window attention across layers:
- L (Long): Full context attention (sequence_len tokens)
- S (Short): Half context attention (sequence_len // 2 tokens)
- "L" = all layers use full context
- "SL" = alternating short/long across layers
- "SSSL" = three short, one long (repeating)
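A minimal sketch of what the config dataclass and pattern expansion might look like. The field names are inferred from their usage elsewhere in this page (sequence_len, n_head, n_kv_head, n_embd, window_pattern); the defaults are illustrative, not nanochat's actual values.

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    # Hypothetical reconstruction: names inferred from this doc, defaults illustrative.
    sequence_len: int = 2048    # maximum context length
    vocab_size: int = 50304     # padded to a multiple of 64 (see Vocab Padding)
    n_layer: int = 12           # number of transformer blocks
    n_head: int = 6             # query heads
    n_kv_head: int = 6          # key/value heads (GQA when < n_head)
    n_embd: int = 768           # model width
    window_pattern: str = "SL"  # per-layer sliding-window pattern

def layer_windows(cfg: GPTConfig) -> list:
    """Expand the repeating window_pattern over all layers:
    'L' -> full context, 'S' -> half context."""
    sizes = {"L": cfg.sequence_len, "S": cfg.sequence_len // 2}
    return [sizes[cfg.window_pattern[i % len(cfg.window_pattern)]]
            for i in range(cfg.n_layer)]

cfg = GPTConfig(n_layer=4, window_pattern="SSL")
print(layer_windows(cfg))  # pattern S,S,L repeats: [1024, 1024, 2048, 1024]
```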
Architecture Components
Rotary Embeddings (RoPE)
The model uses rotary positional embeddings instead of learned positional embeddings:
- Base frequency: 10000 (standard GPT-style)
- Applied to both queries and keys
- Provides relative positional information
- Pre-computed up to 10x the configured sequence length
- Stored in bfloat16 format
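The mechanics above can be sketched in plain Python (the real implementation is vectorized; this is a per-vector illustration using the standard base-10000 frequencies):

```python
import math

def rope_angles(head_dim: int, max_pos: int, base: float = 10000.0):
    """Precompute rotation angles for each (position, channel-pair).
    base=10000 matches the standard GPT-style frequency noted above."""
    half = head_dim // 2
    inv_freq = [base ** (-2 * i / head_dim) for i in range(half)]
    return [[pos * f for f in inv_freq] for pos in range(max_pos)]

def apply_rope(vec, pos, angles):
    """Rotate consecutive channel pairs of a query/key vector by
    position-dependent angles. The effect is purely relative: the
    q.k dot product depends only on the position offset."""
    out = list(vec)
    for i, theta in enumerate(angles[pos]):
        a, b = vec[2 * i], vec[2 * i + 1]
        c, s = math.cos(theta), math.sin(theta)
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out
```

Because both queries and keys are rotated, shifting both positions by the same amount leaves their dot product unchanged, which is what gives RoPE its relative-position property.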
QK Normalization
Both queries and keys are normalized after RoPE is applied.
Group-Query Attention (GQA)
GQA reduces memory usage during inference by sharing key/value heads across multiple query heads:
- n_head: number of query heads (e.g., 6)
- n_kv_head: number of key/value heads (e.g., 2 or 3)
- Constraint: n_head % n_kv_head == 0
When n_kv_head == n_head, this is standard multi-head attention. When n_kv_head < n_head, keys and values are shared across groups of query heads.
Reference: gpt.py:63-68
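A toy sketch of the two ideas from this and the previous section, under the assumption that nanochat's QK norm is a parameter-free RMSNorm (consistent with "No learnable parameters in RMSNorm" above) and that KV heads are shared by repetition across query-head groups:

```python
import math

def rms_norm(v, eps: float = 1e-6):
    """Parameter-free RMSNorm (no learnable gain), applied to q and k."""
    rms = math.sqrt(sum(x * x for x in v) / len(v) + eps)
    return [x / rms for x in v]

def expand_kv_heads(kv_heads, n_head: int):
    """Share each KV head across a group of query heads by repetition.
    With n_kv_head == n_head this degenerates to standard MHA."""
    n_kv_head = len(kv_heads)
    assert n_head % n_kv_head == 0, "n_head must be divisible by n_kv_head"
    group = n_head // n_kv_head
    return [kv for kv in kv_heads for _ in range(group)]

# 2 KV heads serving 6 query heads -> each KV head used by 3 query heads
print(expand_kv_heads([[1.0, 0.0], [0.0, 1.0]], 6))
```

Real implementations avoid materializing the repeated heads (e.g., by reshaping), but the grouping logic is the same.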
MLP with ReLU²
The MLP uses a squared ReLU activation:
- Stronger non-linearity than standard ReLU
- Better gradient flow properties
- No additional parameters
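The activation itself is one line:

```python
def relu_squared(x: float) -> float:
    """ReLU^2: zero for x <= 0, x^2 otherwise. Same sparsity pattern as
    ReLU, but smooth at the origin and with no extra parameters."""
    return max(x, 0.0) ** 2

print([relu_squared(x) for x in (-2.0, 0.0, 3.0)])  # [0.0, 0.0, 9.0]
```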
Sliding Window Attention
The model supports configurable sliding window attention to reduce memory and compute:
- Window sizes computed per-layer from window_pattern
- Format: (left, right) tuple, where left = tokens before the current position (-1 = unlimited) and right = tokens after the current position (0 for causal)
- Examples: (-1, 0) = full causal attention; (1024, 0) = attend to the last 1024 tokens only
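A small sketch of building the per-layer (left, right) tuples, assuming "S" layers get a left window of sequence_len // 2 (the half-context size given under Window Pattern) and "L" layers get unlimited left context:

```python
def window_tuples(window_pattern: str, n_layer: int, sequence_len: int):
    """One (left, right) window per layer: 'L' layers see the full causal
    context (-1 = unlimited left), 'S' layers only the last
    sequence_len // 2 tokens. right = 0 keeps attention causal."""
    out = []
    for i in range(n_layer):
        kind = window_pattern[i % len(window_pattern)]
        left = -1 if kind == "L" else sequence_len // 2
        out.append((left, 0))
    return out

print(window_tuples("SL", 4, 2048))  # [(1024, 0), (-1, 0), (1024, 0), (-1, 0)]
```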
Value Embeddings (ResFormer-style)
Alternating layers include value embeddings that are mixed into the attention values:
- Applied to alternating layers (the last layer is always included)
- Input-dependent gating per head
- Gate range: (0, 2) via 2 * sigmoid
- Uses the first 32 channels of the input for gate computation
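A hypothetical per-head sketch of this gating, combining the bullets above; the exact wiring (a learned weight vector over the first 32 channels, and adding the gated embedding to the attention value) is an assumption for illustration:

```python
import math

def ve_gate_mix(x_head, value_embed, attn_value, gate_weight):
    """Mix a value embedding into an attention value, scaled by an
    input-dependent gate in (0, 2)."""
    # Gate from the first 32 input channels. With zero-initialized
    # ve_gate weights (see the init table below), the gate starts at
    # 2 * sigmoid(0) = 1.0.
    logit = sum(a * w for a, w in zip(x_head[:32], gate_weight[:32]))
    gate = 2.0 / (1.0 + math.exp(-logit))  # 2 * sigmoid, range (0, 2)
    return [v + gate * e for v, e in zip(attn_value, value_embed)]
```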
Untied Weights
The token embedding (wte) and language model head (lm_head) use separate weight matrices:
- Different initialization strategies
- Independent learning rates (unembedding_lr vs embedding_lr)
- Better optimization dynamics
Per-Layer Scalars
The model includes learnable per-layer scaling parameters:
- resid_lambdas: scales the residual stream (init 1.0 = neutral)
- x0_lambdas: blends the initial embedding back in (init 0.1 = small influence)
- Separate optimizer treatment for each
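One plausible reading of how these scalars act on the residual stream, sketched per layer (the exact placement in the forward pass is an assumption; only the roles and init values come from this page):

```python
def blend_residual(x, x0, resid_lambda: float = 1.0, x0_lambda: float = 0.1):
    """Scale the residual stream by resid_lambda and blend the initial
    token embedding x0 back in, weighted by x0_lambda. The documented
    inits (1.0, 0.1) make this near-identity at the start of training."""
    return [resid_lambda * xi + x0_lambda * x0i for xi, x0i in zip(x, x0)]
```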
Weight Initialization
The model uses a custom initialization scheme optimized for stability:
| Parameter | Distribution | Std Dev |
|---|---|---|
| wte (embedding) | Normal | 1.0 |
| lm_head | Normal | 0.001 |
| attn.c_q, c_k, c_v | Uniform | 1/√n_embd |
| attn.c_proj | Zeros | - |
| mlp.c_fc | Uniform | 1/√n_embd |
| mlp.c_proj | Zeros | - |
| value_embeds | Uniform | 1/√n_embd |
| ve_gate | Zeros | - |
| resid_lambdas | - | 1.0 |
| x0_lambdas | - | 0.1 |
Projection layers (attn.c_proj and mlp.c_proj) are initialized to zero for a clean residual path at initialization.
Reference: gpt.py:188-235
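A sketch of the two linear-layer cases from the table, reading its 1/√n_embd entry as the bound of the uniform distribution (an assumption; it could also be the target standard deviation):

```python
import math, random

def init_linear_weight(fan_in: int, fan_out: int, n_embd: int, zero: bool):
    """Projection layers (c_proj) start at zero so every residual branch
    is initially a no-op; other linears draw from
    Uniform(-1/sqrt(n_embd), 1/sqrt(n_embd))."""
    if zero:
        return [[0.0] * fan_in for _ in range(fan_out)]
    bound = 1.0 / math.sqrt(n_embd)
    return [[random.uniform(-bound, bound) for _ in range(fan_in)]
            for _ in range(fan_out)]
```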
Vocab Padding
The vocabulary size is padded to a multiple of 64 for efficiency:
- Tensor core utilization
- Distributed training performance
- Memory access patterns
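The rounding itself is simple integer arithmetic:

```python
def pad_vocab(vocab_size: int, multiple: int = 64) -> int:
    """Round the vocabulary size up to the next multiple of 64 so the
    embedding and lm_head matmul dimensions map cleanly onto tensor cores."""
    return ((vocab_size + multiple - 1) // multiple) * multiple

print(pad_vocab(50257))  # GPT-2's 50257 pads up to 50304
```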
Logit Softcapping
Logits are smoothly capped to prevent extreme values:
- More stable training
- Better numerical properties
- Bounded logit range: [-15, 15]
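This page doesn't show the capping function; a common choice that matches the bounded range above is a scaled tanh, sketched here under that assumption:

```python
import math

def softcap(logit: float, cap: float = 15.0) -> float:
    """Smoothly squash a logit into (-cap, cap) via cap * tanh(x / cap):
    near-identity for small values, saturating at the bound, and
    differentiable everywhere (unlike a hard clamp)."""
    return cap * math.tanh(logit / cap)
```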
Memory Optimizations
BFloat16 Embeddings
Embedding layers are cast to bfloat16 after initialization to reduce memory usage.
FLOPs Estimation
The model includes a method to estimate training FLOPs:
- 6 FLOPs per weight parameter (2 forward + 4 backward)
- Attention QK^T matmul: 12 * h * q * effective_seq_len per layer
- Sliding window adjustments (effective_seq_len varies per layer)
- Excludes embeddings and non-matmul operations
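The accounting rules above can be sketched as follows, taking h as n_head, q as the per-head dimension, and a list of per-layer effective sequence lengths to capture the sliding-window adjustment (a simplified reconstruction, not nanochat's exact method):

```python
def estimate_flops_per_token(n_params: int, n_head: int, head_dim: int,
                             seq_windows) -> int:
    """Rough training FLOPs per token: 6 FLOPs per weight parameter
    (2 forward + 4 backward), plus 12 * h * q * effective_seq_len for
    the attention matmuls in each layer. Embeddings and non-matmul
    ops are excluded, so n_params should count only weight matrices."""
    flops = 6 * n_params
    for eff_len in seq_windows:  # one effective context length per layer
        flops += 12 * n_head * head_dim * eff_len
    return flops
```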
Related
Optimizer
MuonAdamW optimizer details
Flash Attention
Flash Attention 3 integration