Overview
The GPT class implements a decoder-only transformer model with several modern improvements over the original GPT architecture.
GPTConfig
Configuration dataclass for the GPT model. Its fields specify:
- Maximum context length for the model
- Size of the vocabulary (typically 2^15)
- Number of transformer layers (depth). This is the primary complexity dial.
- Number of query attention heads
- Number of key/value heads for Group-Query Attention (GQA). Set equal to n_head for multi-head attention, or fewer for GQA.
- Model embedding dimension (width). Typically calculated as depth * aspect_ratio.
- Sliding window attention pattern, tiled across layers: L = long (full context), S = short (half context). The final layer always uses full context. Examples: "L" (all full), "SL" (alternating), "SSSL" (three short, one long).
GPT Class
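The GPT class tiles the configured sliding-window pattern across its layers. As a plain-Python sketch of that tiling (the helper name `layer_window_sizes` and the exact halving rule are illustrative assumptions, not nanochat's actual code):

```python
# Sketch: tiling a sliding-window pattern string across layers.
# Assumes "L" = full context, "S" = half context, and that the final
# layer is always forced to full context, as described above.
def layer_window_sizes(pattern: str, n_layer: int, sequence_len: int) -> list[int]:
    sizes = []
    for i in range(n_layer):
        char = pattern[i % len(pattern)]            # tile the pattern across layers
        full = (char == "L") or (i == n_layer - 1)  # final layer: always full context
        sizes.append(sequence_len if full else sequence_len // 2)
    return sizes
```

For example, `layer_window_sizes("SSSL", 8, 2048)` yields two tiles of the pattern, with layers 3 and 7 (and the final layer in any case) attending over the full 2048-token context.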
Architecture Features
The nanochat GPT implementation includes several modern improvements:
- Rotary Embeddings: RoPE for relative positional encoding (no learned positional embeddings)
- QK Normalization: normalizes queries and keys for stable attention
- Untied Weights: separate token embedding and language model head weights
- ReLU² Activation: squared ReLU activation in MLP layers
- Group-Query Attention: GQA support for efficient inference with a KV cache
- Flash Attention 3: automatic FA3 on Hopper+ GPUs, with an SDPA fallback elsewhere
- Sliding Window: configurable attention window patterns per layer
- No Bias Terms: all linear layers use bias=False
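Two of the features above have simple closed-form definitions. A plain-Python sketch for clarity (the real implementation operates on PyTorch tensors):

```python
import math

def rms_norm(x: list[float], eps: float = 1e-6) -> list[float]:
    # Parameter-free RMSNorm: x / sqrt(mean(x^2) + eps).
    # In QK normalization, this is applied to queries and keys before attention.
    scale = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / scale for v in x]

def relu2(x: list[float]) -> list[float]:
    # ReLU² activation used in the MLP: relu(x) squared.
    return [max(v, 0.0) ** 2 for v in x]
```

After `rms_norm`, the mean of the squared entries is (up to eps) 1, which keeps attention logits at a stable scale regardless of the magnitude of the incoming activations.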
Forward Pass
Inputs:
- Input token indices of shape (B, T), where B is batch size and T is sequence length
- Optional target token indices for computing the loss. If provided, the forward pass returns (logits, loss); otherwise it returns (logits, None).
- Optional key-value cache for efficient autoregressive generation. Used by the Engine class.
Returns:
- Output logits of shape (B, T, vocab_size)
- Cross-entropy loss if targets are provided, otherwise None
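A toy, shape-level sketch of this contract (the real forward pass runs the full transformer stack; the stand-in logits and names here are illustrative only):

```python
import math
import random

def forward(idx, targets=None, vocab_size=8):
    # idx: (B, T) nested lists of token indices.
    B, T = len(idx), len(idx[0])
    # Stand-in "logits": a (B, T, vocab_size) structure of floats.
    logits = [[[random.random() for _ in range(vocab_size)] for _ in range(T)]
              for _ in range(B)]
    if targets is None:
        return logits, None
    # Mean cross-entropy over all (B, T) positions: logsumexp - logit[target].
    total = 0.0
    for b in range(B):
        for t in range(T):
            row = logits[b][t]
            logsumexp = math.log(sum(math.exp(v) for v in row))
            total += logsumexp - row[targets[b][t]]
    return logits, total / (B * T)
```

The two return shapes mirror the documented behavior: `(logits, None)` during generation, `(logits, loss)` during training.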
Optimizer Configuration
The model provides a configure_optimizers() method that groups parameters for the MuonAdamW optimizer:
- Muon group: Weight matrices (Muon optimizer)
- Adam embedding: Token embeddings (AdamW)
- Adam unembedding: Language model head (AdamW)
- Adam scalar: Scalar parameters like gates (AdamW)
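The grouping idea behind this can be sketched as routing parameters by name and dimensionality; the names and the ndim heuristic below are illustrative assumptions, not nanochat's exact rules:

```python
def group_params(named_params):
    # named_params: iterable of (name, ndim) pairs for the model's parameters.
    groups = {"muon": [], "adam_embedding": [], "adam_unembedding": [], "adam_scalar": []}
    for name, ndim in named_params:
        if "wte" in name:            # token embedding -> AdamW (assumed name)
            groups["adam_embedding"].append(name)
        elif "lm_head" in name:      # language model head -> AdamW (assumed name)
            groups["adam_unembedding"].append(name)
        elif ndim < 2:               # gates and other scalar parameters -> AdamW
            groups["adam_scalar"].append(name)
        else:                        # remaining weight matrices -> Muon
            groups["muon"].append(name)
    return groups
```

Muon is designed for 2D weight matrices, which is why embeddings, the unembedding, and scalars fall through to AdamW groups.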
Helper Functions
norm(x)
Purely functional RMSNorm with no learnable parameters.
apply_rotary_emb(x, cos, sin)
Applies rotary position embeddings.
has_ve(layer_idx, n_layer)
Determines if a layer should have Value Embedding.
See Also
- GPT Model Architecture - Detailed architecture overview
- Flash Attention - Flash Attention 3 integration
- Optimizer - MuonAdamW optimizer details