Overview

Nanochat uses MuonAdamW, a hybrid optimizer that combines two different optimization algorithms:
  • Muon (MomentUm Orthogonalized by Newton-Schulz): For 2D matrix parameters (weights in attention and MLP layers)
  • AdamW: For embeddings, language model head, and scalar parameters
This design exploits the different geometric properties of these parameter types for better training dynamics.

Architecture-Specific Parameter Groups

The GPT model’s optimizer setup creates distinct parameter groups with different learning rates:

AdamW Groups

| Group | Parameters | Base LR | Beta1 | Beta2 | Weight Decay |
| --- | --- | --- | --- | --- | --- |
| Unembedding | lm_head | 0.004 | 0.8 | 0.95 | 0.0 |
| Embeddings | wte, value_embeds | 0.2 | 0.8 | 0.95 | 0.0 |
| Resid scalars | resid_lambdas | 0.005 | 0.8 | 0.95 | 0.0 |
| X0 scalars | x0_lambdas | 0.5 | 0.96 | 0.95 | 0.0 |
Learning Rate Scaling: All AdamW learning rates are scaled by 1/√(n_embd/768) to maintain consistent behavior across model sizes.
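For illustration, the 1/√(n_embd/768) rule can be traced in a few lines of plain Python (the helper name is made up for this sketch, not nanochat's API):

```python
import math

def dmodel_lr_scale(n_embd: int, base_dim: int = 768) -> float:
    """LR multiplier that shrinks as the model dimension grows.

    Hypothetical helper illustrating the 1/sqrt(n_embd/768) rule.
    """
    return 1.0 / math.sqrt(n_embd / base_dim)

# At the reference width (768) the scale is exactly 1.0;
# a d=1536 model multiplies every AdamW LR by 1/sqrt(2).
print(dmodel_lr_scale(768))            # 1.0
print(0.004 * dmodel_lr_scale(1536))   # scaled unembedding LR at d=1536
```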

Muon Groups

Matrix parameters are grouped by shape and optimized together:
| Parameters | Base LR | Momentum | NS Steps | Beta2 | Weight Decay |
| --- | --- | --- | --- | --- | --- |
| All 2D matrices | 0.02 | 0.95 | 5 | 0.95 | configurable |
Matrices include:
  • Attention: c_q, c_k, c_v, c_proj, ve_gate
  • MLP: c_fc, c_proj
Learning Rate Scaling: Muon LR is scaled by max(1.0, rows/cols)^0.5 per group to account for matrix aspect ratio. Reference: gpt.py:348-386
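The aspect-ratio boost can be sketched the same way (again a hypothetical helper, not nanochat's code):

```python
def muon_lr_for_shape(base_lr: float, rows: int, cols: int) -> float:
    """Per-group Muon LR with the max(1, rows/cols)**0.5 boost for tall matrices."""
    return base_lr * max(1.0, rows / cols) ** 0.5

# Square attention projections keep the base LR; a tall (4*n_embd, n_embd)
# MLP down-projection at d=768 gets a sqrt(4) = 2x boost.
print(muon_lr_for_shape(0.02, 768, 768))    # 0.02
print(muon_lr_for_shape(0.02, 3072, 768))   # 0.04
```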

Muon Optimizer Details

Algorithm Overview

Muon performs momentum-based optimization followed by orthogonalization:
  1. Nesterov Momentum: Apply momentum to gradients
  2. Polar Express: Orthogonalize the update using Newton-Schulz iteration
  3. Variance Reduction (NorMuon): Normalize per-neuron update scales
  4. Cautious Weight Decay: Apply decay only when update and parameter agree in sign

Step 1: Nesterov Momentum

```python
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
g = stacked_grads.lerp_(momentum_buffer, momentum)
```
  • Momentum coefficient: 0.95
  • Uses Nesterov-style lookahead
  • Accumulated per parameter
Reference: optim.py:110-112
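The two lerp_ calls are a compact way of writing the usual EMA-plus-lookahead form. A scalar sketch in plain Python (torch-free) makes the equivalence explicit:

```python
def nesterov_step(buf: float, grad: float, momentum: float = 0.95):
    """Mirror of the two tensor lerp_ calls, on plain floats.

    lerp(a, b, w) = a + w * (b - a) = (1 - w) * a + w * b
    """
    buf = buf + (1 - momentum) * (grad - buf)   # buf.lerp_(grad, 1-m): EMA of grads
    g = grad + momentum * (buf - grad)          # grad.lerp_(buf, m): lookahead blend
    return buf, g

buf, g = nesterov_step(0.0, 1.0)
# Same result as the textbook form g = (1-m)*grad + m*buf.
print(buf, g)
```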

Step 2: Polar Express Orthogonalization

Replaces Newton-Schulz with Polar Express for better convergence:
```python
X = g / (g.norm() * 1.02 + 1e-6)  # Normalize
for a, b, c in polar_express_coeffs[:5]:
    A = X @ X.mT  # or X.mT @ X for tall matrices
    B = b * A + c * (A @ A)
    X = a * X + B @ X
```
  • 5 iterations (configurable via ns_steps)
  • Automatically handles tall vs. wide matrices
  • Computed in bfloat16 for efficiency
  • Coefficients optimized for safety_factor=0.02, cushion=2
The result is approximately U S' V^T where S' has diagonal entries ~ Uniform(0.5, 1.5), which empirically works as well as true UV^T orthogonalization. Reference: optim.py:115-127, optim.py:80-88
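Nanochat's Polar Express coefficient schedule is not reproduced here, but the behavior of this iteration can be checked with the widely published Newton-Schulz quintic coefficients from the original Muon (an assumption for this sketch, not nanochat's tuned values). For a diagonal matrix the iteration acts independently on each singular value, so plain floats suffice:

```python
import math

# Quintic coefficients from the original Muon Newton-Schulz iteration
# (an assumption -- nanochat's Polar Express uses its own tuned schedule).
A_, B_, C_ = 3.4445, -4.7750, 2.0315

def ns_map(x: float) -> float:
    """One iteration restricted to a single singular value:
    X <- a*X + (b*(X X^T) + c*(X X^T)^2) X  ==>  x <- a*x + b*x^3 + c*x^5
    """
    return A_ * x + B_ * x**3 + C_ * x**5

# Singular values of a diagonal "gradient", Frobenius-normalized as above.
svals = [0.3, 0.6, 0.9]
fro = math.sqrt(sum(s * s for s in svals)) * 1.02 + 1e-6
svals = [s / fro for s in svals]

for _ in range(5):  # ns_steps = 5
    svals = [ns_map(s) for s in svals]

print(svals)  # all values driven toward ~1, landing roughly in (0.5, 1.5)
```

This matches the document's claim: the iteration does not produce exactly unit singular values, just values clustered around 1.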

Step 3: Variance Reduction (NorMuon)

Normalizes update magnitudes across neurons/columns:
```python
# Compute per-neuron variance
v_mean = g.square().mean(dim=red_dim, keepdim=True)

# Track with EMA (beta2=0.95)
second_momentum_buffer.lerp_(v_mean, 1 - beta2)

# Compute adaptive step size
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()

# Scale to preserve total norm
scaled_sq_sum = (v_mean * red_dim_size) * step_size.square()
v_norm_new = scaled_sq_sum.sum().sqrt()
final_scale = step_size * (v_norm / v_norm_new)

g = g * final_scale
```
This addresses the issue that Muon’s output has non-uniform scales across neurons after orthogonalization.
  • Second moment tracked per row or column (factored, not full matrix)
  • Beta2: 0.95
  • Preserves overall gradient norm while equalizing per-neuron magnitudes
Reference: optim.py:129-140
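A torch-free sketch on a 2x2 example shows the effect: rows with wildly different magnitudes end up with equal norms while the total Frobenius norm is preserved. (The EMA tracking is skipped here; the second-moment buffer is taken to equal the current per-row variance.)

```python
import math

g = [[3.0, 4.0], [0.3, 0.4]]   # row 0 is 10x larger than row 1
red_dim_size = len(g[0])

# Per-neuron (per-row) mean of squared entries.
v_mean = [sum(x * x for x in row) / red_dim_size for row in g]

# Overall norm before scaling (recoverable from v_mean).
v_norm = math.sqrt(sum(v * red_dim_size for v in v_mean))

# Adaptive per-row step size, then a global rescale that restores v_norm.
step = [1.0 / math.sqrt(max(v, 1e-10)) for v in v_mean]
v_norm_new = math.sqrt(sum(v * red_dim_size * s * s for v, s in zip(v_mean, step)))
final = [s * (v_norm / v_norm_new) for s in step]

out = [[x * f for x in row] for row, f in zip(g, final)]
row_norms = [math.sqrt(sum(x * x for x in row)) for row in out]
print(row_norms)   # both rows now have the same norm
```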

Step 4: Cautious Weight Decay

Applies weight decay only when the update and parameter agree in sign:
```python
mask = (g * stacked_params) >= 0
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
```
When g and the parameter share a sign, the gradient step already shrinks the parameter toward zero, so decay joins in; when they disagree, the gradient step is growing the parameter, and decay is switched off rather than fighting it. Reference: optim.py:142-146
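The masking logic reduces to a one-line scalar sketch (plain Python; cautious_update is a made-up name for illustration):

```python
def cautious_update(p: float, g: float, lr: float = 0.1, wd: float = 0.1) -> float:
    """Scalar version of the masked update: decay only where g and p agree in sign."""
    mask = 1.0 if g * p >= 0 else 0.0
    return p - (lr * g + lr * wd * p * mask)

# Gradient already shrinks the weight -> decay joins in:  1 - 0.05 - 0.01 = 0.94
print(cautious_update(1.0, 0.5))
# Gradient grows the weight -> decay is skipped:          1 + 0.05 = 1.05
print(cautious_update(1.0, -0.5))
```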

AdamW Optimizer Details

Standard AdamW with decoupled weight decay:
```python
# Weight decay (decoupled)
p.mul_(1 - lr * wd)

# Update exponential moving averages
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.lerp_(grad.square(), 1 - beta2)

# Bias correction
bias1 = 1 - beta1 ** step
bias2 = 1 - beta2 ** step

# Parameter update
denom = (exp_avg_sq / bias2).sqrt() + eps
step_size = lr / bias1
p.add_(exp_avg / denom, alpha=-step_size)
```
Reference: optim.py:20-49
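One consequence of bias correction worth noting: the very first update is approximately lr * sign(grad), independent of the gradient's magnitude. A scalar sketch in plain Python (state starts at zero, as it does before step 1):

```python
import math

def adamw_first_step(grad: float, lr: float = 0.004,
                     beta1: float = 0.8, beta2: float = 0.95,
                     eps: float = 1e-10) -> float:
    """Size of the first AdamW update for a scalar parameter with zero state."""
    exp_avg = (1 - beta1) * grad            # lerp_ from a zero buffer
    exp_avg_sq = (1 - beta2) * grad * grad
    bias1, bias2 = 1 - beta1, 1 - beta2     # step = 1
    denom = math.sqrt(exp_avg_sq / bias2) + eps
    return (lr / bias1) * exp_avg / denom

# Bias correction cancels the (1 - beta) factors, leaving lr * g / (|g| + eps):
# a tiny and a huge gradient produce (almost) the same first step, ~lr.
print(adamw_first_step(1e-3), adamw_first_step(1e3))
```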

Implementations

MuonAdamW (Single GPU)

Baseline implementation for single GPU training:
  • Each parameter optimized individually (AdamW)
  • Matrix parameters stacked by shape for efficient Muon steps
  • No distributed communication
  • Used for debugging and small-scale experiments
Reference: optim.py:152-291

DistMuonAdamW (Multi-GPU)

Optimized for distributed training with ZeRO-2 style sharding.
AdamW Communication:
  • Small params (<1024 elements): all_reduce gradients, replicate state
  • Large params: reduce_scatter gradients, shard state across ranks, all_gather updates
Muon Communication:
  • Stack all parameters in group into single tensor: (K, *shape)
  • Divide K parameters across N ranks: each owns ceil(K/N) parameters
  • reduce_scatter stacked gradients → each rank gets its chunk
  • Each rank computes Muon update for its chunk only
  • all_gather updated parameters back to all ranks
  • Optimizer state sharded by chunk (momentum_buffer, second_momentum_buffer)
3-Phase Async Pattern:
  1. Launch all async reduce operations (don’t wait)
  2. For each group: wait for reduce → compute update → launch gather
  3. Wait for all gathers → copy parameters back
This maximizes overlap between communication and computation. Reference: optim.py:297-533
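The parameter-to-rank assignment described above can be sketched in a few torch-free lines (the function name is made up; the torch.distributed calls are omitted since they need a live process group):

```python
import math

def muon_chunks(K: int, world_size: int):
    """Which stacked-parameter indices each rank owns: contiguous chunks of
    ceil(K / world_size), ZeRO-2 style (the last rank may own fewer)."""
    per_rank = math.ceil(K / world_size)
    return [list(range(r * per_rank, min((r + 1) * per_rank, K)))
            for r in range(world_size)]

# 10 same-shape matrices across 4 ranks: ranks 0-2 own 3 each, rank 3 owns 1.
print(muon_chunks(10, 4))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```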

Fused Kernels

Both optimizers use @torch.compile fused kernels for efficiency:

adamw_step_fused

Single compiled graph for:
  • Weight decay
  • Momentum update
  • Bias correction
  • Parameter update
Reference: optim.py:20-49

muon_step_fused

Single compiled graph for:
  • Nesterov momentum
  • Polar Express (5 iterations)
  • Variance reduction
  • Cautious update
Reference: optim.py:90-146

0-D CPU Tensors

Hyperparameters (lr, beta1, beta2, etc.) are stored as 0-D CPU tensors:
```python
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
```
This avoids recompilation when values change (e.g., during LR scheduling). Reference: optim.py:182-192, optim.py:358-367

Parameter Grouping Strategy

Muon requires all parameters in a group to have the same shape. The GPT model achieves this naturally:
```python
for shape in sorted({p.shape for p in matrix_params}):
    group_params = [p for p in matrix_params if p.shape == shape]
    param_groups.append(dict(
        kind='muon', params=group_params, lr=matrix_lr, ...
    ))
```
Typical shapes in a GPT model:
  • (n_embd, n_embd): attention projections
  • (n_embd, 4*n_embd): MLP up-projection
  • (4*n_embd, n_embd): MLP down-projection
Reference: gpt.py:375-380
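The grouping can be traced with plain shape tuples (an illustrative count, ignoring ve_gate and value embeddings for brevity):

```python
# Hypothetical 2D weight shapes for a 12-layer, d=768 model:
# per layer, 4 square attention projections plus the MLP up/down pair.
shapes = ([(768, 768)] * 4 + [(3072, 768), (768, 3072)]) * 12

groups = {}
for shape in shapes:
    groups[shape] = groups.get(shape, 0) + 1

for shape in sorted(groups):
    print(shape, "x", groups[shape])
# Three Muon groups in total: 48 square matrices and 12 of each MLP shape.
```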

Memory Requirements

AdamW State

Per parameter:
  • exp_avg: same shape as parameter
  • exp_avg_sq: same shape as parameter
  • Total: 2x parameter memory

Muon State

Per parameter group (all same shape):
  • momentum_buffer: (K, *shape) where K = number of params in group
  • second_momentum_buffer: factored, either (K, rows, 1) or (K, 1, cols)
  • Total: ~1x parameter memory (momentum) + small factored second moment
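The comparison can be made concrete with a little arithmetic (fp32 state assumed; the shape is the d=768 MLP example above, and the (K, rows, 1) factorization is chosen for illustration):

```python
K, rows, cols = 12, 3072, 768    # e.g. all MLP up-projections in a 12-layer model
param_elems = K * rows * cols

adamw_state = 2 * param_elems                 # exp_avg + exp_avg_sq, full shape each
muon_state = param_elems + K * rows * 1       # momentum + factored second moment

print(adamw_state / param_elems)   # 2.0 -- AdamW state is 2x the parameters
print(muon_state / param_elems)    # 1 + 1/cols, i.e. barely above 1x
```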
Muon is more memory-efficient than AdamW for 2D matrices.

Default Configuration

The nanochat default optimizer configuration:

```python
optimizer = model.setup_optimizer(
    unembedding_lr=0.004,
    embedding_lr=0.2,
    matrix_lr=0.02,
    weight_decay=0.0,        # or 0.1 for larger models
    adam_betas=(0.8, 0.95),
    scalar_lr=0.5,
)
```
These values are tuned for 768-dimensional models and scale automatically via the dmodel_lr_scale factor.

References

  • GPT Architecture: model architecture and parameter setup
  • Dataloader: training data pipeline