Overview
Nanochat uses MuonAdamW, a hybrid optimizer that combines two different optimization algorithms:
- Muon (Momentum Orthogonalized by Newton-Schulz): for 2D matrix parameters (weights in attention and MLP layers)
- AdamW: for embeddings, the language model head, and scalar parameters
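This split can be sketched as a small routing helper (illustrative only; the real grouping logic in gpt.py is more involved):

```python
def route_param(name: str, ndim: int) -> str:
    """Illustrative routing rule: 2-D matrices go to Muon,
    everything else to AdamW. Embeddings and the LM head are
    2-D but are still handled by AdamW."""
    if name in ("wte", "value_embeds", "lm_head"):
        return "adamw"
    return "muon" if ndim == 2 else "adamw"

print(route_param("c_q", 2))         # attention projection -> muon
print(route_param("lm_head", 2))     # unembedding -> adamw
print(route_param("x0_lambdas", 1))  # scalar group -> adamw
```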
Architecture-Specific Parameter Groups
The GPT model’s optimizer setup creates distinct parameter groups with different learning rates.
AdamW Groups
| Group | Parameters | Base LR | Beta1 | Beta2 | Weight Decay |
|---|---|---|---|---|---|
| Unembedding | lm_head | 0.004 | 0.8 | 0.95 | 0.0 |
| Embeddings | wte, value_embeds | 0.2 | 0.8 | 0.95 | 0.0 |
| Resid scalars | resid_lambdas | 0.005 | 0.8 | 0.95 | 0.0 |
| X0 scalars | x0_lambdas | 0.5 | 0.96 | 0.95 | 0.0 |
These base learning rates are scaled ∝ 1/√(n_embd/768) to maintain consistent behavior across model sizes.
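The scaling rule can be computed directly; a minimal sketch (the helper name is hypothetical, the 768 reference width comes from the rule above):

```python
def scaled_lr(base_lr: float, n_embd: int, ref: int = 768) -> float:
    """Scale an AdamW base LR by 1 / sqrt(n_embd / ref)."""
    return base_lr * (n_embd / ref) ** -0.5

# At the reference width the LR is unchanged; wider models get smaller LRs.
print(scaled_lr(0.2, 768))   # 0.2
print(scaled_lr(0.2, 3072))  # 0.1  (sqrt(4) = 2x smaller)
```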
Muon Groups
Matrix parameters are grouped by shape and optimized together:
| Parameters | Base LR | Momentum | NS Steps | Beta2 | Weight Decay |
|---|---|---|---|---|---|
| All 2D matrices | 0.02 | 0.95 | 5 | 0.95 | configurable |
- Attention: c_q, c_k, c_v, c_proj, ve_gate
- MLP: c_fc, c_proj
The learning rate for each group is scaled by max(1.0, rows/cols)^0.5 to account for matrix aspect ratio.
Reference: gpt.py:348-386
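The aspect-ratio multiplier can be computed directly (the helper name is hypothetical):

```python
def aspect_scale(rows: int, cols: int) -> float:
    """Muon LR multiplier: max(1.0, rows/cols) ** 0.5."""
    return max(1.0, rows / cols) ** 0.5

print(aspect_scale(768, 768))   # square matrix  -> 1.0
print(aspect_scale(3072, 768))  # tall 4x matrix -> 2.0
print(aspect_scale(768, 3072))  # wide matrix    -> 1.0 (never shrinks)
```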
Muon Optimizer Details
Algorithm Overview
Muon performs momentum-based optimization followed by orthogonalization:
- Nesterov Momentum: Apply momentum to gradients
- Polar Express: Orthogonalize the update using Newton-Schulz iteration
- Variance Reduction (NorMuon): Normalize per-neuron update scales
- Cautious Weight Decay: Apply decay only when update and parameter agree in sign
Step 1: Nesterov Momentum
- Momentum coefficient: 0.95
- Uses Nesterov-style lookahead
- Accumulated per parameter
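A sketch of the Nesterov-style momentum step, assuming the common Muon formulation (update the buffer, then blend in a lookahead; the exact optim.py code may differ):

```python
import numpy as np

def nesterov_update(grad, buf, momentum=0.95):
    """Update the momentum buffer in place, then return the
    Nesterov lookahead: grad + momentum * new_buffer."""
    buf *= momentum
    buf += grad
    return grad + momentum * buf

g = np.ones((4, 4))
buf = np.zeros((4, 4))
u = nesterov_update(g, buf, momentum=0.95)
print(u[0, 0])  # 1 + 0.95 * 1 = 1.95
```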
Step 2: Polar Express Orthogonalization
Replaces Newton-Schulz with Polar Express for better convergence:
- 5 iterations (configurable via ns_steps)
- Automatically handles tall vs. wide matrices
- Computed in bfloat16 for efficiency
- Coefficients optimized for safety_factor=0.02, cushion=2
The result is approximately U S' V^T, where S' has diagonal entries ~ Uniform(0.5, 1.5); empirically this works as well as exact UV^T orthogonalization.
Reference: optim.py:115-127, optim.py:80-88
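For intuition, the orthogonalization can be illustrated with the classical cubic Newton-Schulz iteration. The actual Polar Express coefficients in optim.py are tuned quintic polynomials, so this is a simplified stand-in, not the production code:

```python
import numpy as np

def newton_schulz_orthogonalize(G, steps=10):
    """Approximate the polar factor U V^T of G via the cubic
    Newton-Schulz iteration X <- 1.5 X - 0.5 (X X^T) X.
    Converges when all singular values of X lie in (0, sqrt(3)),
    which normalizing by the Frobenius norm guarantees."""
    X = G / np.linalg.norm(G)
    tall = X.shape[0] > X.shape[1]
    if tall:                      # iterate on the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = 1.5 * X - 0.5 * A @ X
    return X.T if tall else X

# Build a test matrix with known, well-separated singular values.
rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((8, 8)))
V, _ = np.linalg.qr(rng.standard_normal((8, 8)))
G = U @ np.diag(np.linspace(0.5, 2.0, 8)) @ V.T
O = newton_schulz_orthogonalize(G, steps=25)
print(np.allclose(O @ O.T, np.eye(8), atol=1e-6))  # True
```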
Step 3: Variance Reduction (NorMuon)
Normalizes update magnitudes across neurons/columns:
- Second moment tracked per row or column (factored, not full matrix)
- Beta2: 0.95
- Preserves overall gradient norm while equalizing per-neuron magnitudes
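A sketch of the per-neuron normalization idea, assuming a per-row factored second moment (the exact NorMuon formulation in optim.py may differ):

```python
import numpy as np

def normuon_normalize(update, row_second_moment, beta2=0.95, eps=1e-8):
    """Track a factored (per-row) second moment, rescale each row so
    per-neuron magnitudes are equalized, then restore the overall norm."""
    row_ms = (update ** 2).mean(axis=1, keepdims=True)   # (rows, 1)
    row_second_moment *= beta2
    row_second_moment += (1 - beta2) * row_ms
    before = np.linalg.norm(update)
    out = update / (np.sqrt(row_second_moment) + eps)
    out *= before / (np.linalg.norm(out) + eps)          # preserve norm
    return out

u = np.array([[3.0, 4.0], [0.3, 0.4]])      # rows differ 10x in scale
state = np.zeros((2, 1))
v = normuon_normalize(u, state, beta2=0.0)  # beta2=0 -> use current stats
print(np.linalg.norm(v[0]), np.linalg.norm(v[1]))  # rows now equal scale
```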
Step 4: Cautious Weight Decay
Applies weight decay only when the update and parameter agree in sign.
AdamW Optimizer Details
Standard AdamW with decoupled weight decay.
Implementations
MuonAdamW (Single GPU)
Baseline implementation for single-GPU training:
- Each parameter optimized individually (AdamW)
- Matrix parameters stacked by shape for efficient Muon steps
- No distributed communication
- Used for debugging and small-scale experiments
DistMuonAdamW (Multi-GPU)
Optimized for distributed training with ZeRO-2 style sharding.
AdamW Communication:
- Small params (<1024 elements): all_reduce gradients, replicate state
- Large params: reduce_scatter gradients, shard state across ranks, all_gather updates
Muon Communication:
- Stack all parameters in a group into a single tensor: (K, *shape)
- Divide the K parameters across N ranks: each rank owns ceil(K/N) parameters
- reduce_scatter stacked gradients → each rank gets its chunk
- Each rank computes the Muon update for its chunk only
- all_gather updated parameters back to all ranks
- Optimizer state sharded by chunk (momentum_buffer, second_momentum_buffer)
To overlap communication with compute:
- Launch all async reduce operations (don’t wait)
- For each group: wait for reduce → compute update → launch gather
- Wait for all gathers → copy parameters back
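The chunking arithmetic behind the sharding can be sketched as follows (rank-assignment details are illustrative):

```python
import math

def muon_chunks(K: int, world_size: int):
    """Split K stacked parameters into per-rank chunks of size
    ceil(K / world_size); the last rank may own fewer."""
    chunk = math.ceil(K / world_size)
    return [(r * chunk, min((r + 1) * chunk, K)) for r in range(world_size)]

# 10 matrices across 4 ranks: chunk size ceil(10/4) = 3.
print(muon_chunks(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```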
Fused Kernels
Both optimizers use @torch.compile fused kernels for efficiency:
adamw_step_fused
Single compiled graph for:
- Weight decay
- Momentum update
- Bias correction
- Parameter update
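These fused operations can be sketched as one plain NumPy step (hyperparameter defaults are illustrative, not nanochat's exact values):

```python
import numpy as np

def adamw_step(p, g, m, v, step, lr=0.004, beta1=0.8, beta2=0.95,
               eps=1e-10, weight_decay=0.0):
    """One AdamW update: decoupled decay, moments, bias correction."""
    p *= 1.0 - lr * weight_decay             # decoupled weight decay
    m *= beta1
    m += (1 - beta1) * g                     # first moment
    v *= beta2
    v += (1 - beta2) * g * g                 # second moment
    m_hat = m / (1 - beta1 ** step)          # bias correction
    v_hat = v / (1 - beta2 ** step)
    p -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return p

p, m, v = np.array([1.0]), np.zeros(1), np.zeros(1)
g = np.array([0.5])
p = adamw_step(p, g, m, v, step=1, lr=0.1)
print(p)  # bias correction makes the first step move by ~lr: [0.9]
```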
muon_step_fused
Single compiled graph for:
- Nesterov momentum
- Polar Express (5 iterations)
- Variance reduction
- Cautious update
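The cautious weight decay component can be sketched as a sign-agreement mask (an illustration of the rule as described above, not the exact optim.py code):

```python
import numpy as np

def cautious_weight_decay(p, update, lr, weight_decay):
    """Apply decoupled weight decay only where the update and the
    parameter agree in sign, then apply the update itself."""
    mask = (np.sign(update) == np.sign(p)).astype(p.dtype)
    p -= lr * weight_decay * mask * p   # decay only where signs agree
    p -= lr * update
    return p

p = np.array([1.0, -1.0,  1.0])
u = np.array([0.5, -0.5, -0.5])   # signs: agree, agree, disagree
p = cautious_weight_decay(p, u, lr=0.1, weight_decay=1.0)
print(p)  # decay applied to the first two entries only
```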
0-D CPU Tensors
Hyperparameters (lr, beta1, beta2, etc.) are stored as 0-D CPU tensors so their values can change (e.g. under LR scheduling) without triggering torch.compile recompilation.
Parameter Grouping Strategy
Muon requires all parameters in a group to have the same shape. The GPT model achieves this naturally:
- (n_embd, n_embd): attention projections
- (n_embd, 4*n_embd): MLP up-projection
- (4*n_embd, n_embd): MLP down-projection
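Grouping by shape can be sketched with a dict keyed on parameter shape (the (name, shape) pairs below are hypothetical, following the shape conventions listed above):

```python
from collections import defaultdict

# Hypothetical (name, shape) pairs for n_embd = 768.
params = [
    ("h.0.attn.c_q.weight", (768, 768)),
    ("h.0.attn.c_k.weight", (768, 768)),
    ("h.0.mlp.c_fc.weight", (768, 3072)),
    ("h.1.mlp.c_fc.weight", (768, 3072)),
]

groups = defaultdict(list)
for name, shape in params:
    groups[shape].append(name)   # same-shape params are stacked for Muon

for shape, names in sorted(groups.items()):
    print(shape, "->", len(names), "params")
```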
Memory Requirements
AdamW State
Per parameter:
- exp_avg: same shape as parameter
- exp_avg_sq: same shape as parameter
- Total: 2x parameter memory
Muon State
Per parameter group (all same shape):
- momentum_buffer: (K, *shape) where K = number of params in the group
- second_momentum_buffer: factored, either (K, rows, 1) or (K, 1, cols)
- Total: ~1x parameter memory (momentum) + small factored second moment
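These state sizes can be checked with simple arithmetic (using min(rows, cols) for the factored moment is an assumption here; the list above allows either orientation):

```python
def adamw_state_elems(shape):
    """exp_avg + exp_avg_sq: 2x the parameter element count."""
    n = 1
    for d in shape:
        n *= d
    return 2 * n

def muon_state_elems(K, rows, cols):
    """Stacked momentum (K*rows*cols) plus a factored second moment
    of roughly K * min(rows, cols) entries."""
    return K * rows * cols + K * min(rows, cols)

# 12 stacked (768, 3072) matrices:
print(adamw_state_elems((768, 3072)) * 12)  # 56623104 elements
print(muon_state_elems(12, 768, 3072))      # 28320768 elements (~half)
```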
Recommended Hyperparameters
The values above follow the nanochat default configuration; base learning rates are additionally adjusted by the dmodel_lr_scale factor.
References
- Muon: https://kellerjordan.github.io/posts/muon/
- Polar Express: https://arxiv.org/pdf/2505.16932
- NorMuon: https://arxiv.org/pdf/2510.05491
- AdamW: https://arxiv.org/abs/1711.05101
- modded-nanogpt: https://github.com/KellerJordan/modded-nanogpt
Related
GPT Architecture
Model architecture and parameter setup
Dataloader
Training data pipeline