
Overview

The nanochat.optim module provides combined optimizers that use Muon for 2D matrix parameters and AdamW for others. Two versions are available:
  • MuonAdamW - Single GPU version
  • DistMuonAdamW - Distributed multi-GPU version with ZeRO-2 style sharding

MuonAdamW

Combined optimizer for single GPU training.
class MuonAdamW(param_groups: list[dict])

Parameters

param_groups
list[dict]
required
List of parameter groups. Each group is a dict containing:

Common fields:
  • params (list): List of parameters
  • kind (str): Either 'adamw' or 'muon'
For AdamW groups:
  • lr (float): Learning rate
  • betas (tuple): Coefficients for computing running averages
  • eps (float): Term added to denominator for numerical stability
  • weight_decay (float): Weight decay coefficient
For Muon groups:
  • lr (float): Learning rate
  • momentum (float): Momentum coefficient
  • ns_steps (int): Number of Newton-Schulz/Polar Express iterations
  • beta2 (float): Beta2 for second moment
  • weight_decay (float): Weight decay coefficient

Methods

step

@torch.no_grad()
def step()
Performs a single optimization step.

Notes

  • AdamW: Uses fused AdamW optimizer step for non-matrix parameters (embeddings, scalars, biases)
  • Muon: MomentUm Orthogonalized by Newton-Schulz, for 2D matrix parameters
  • The Muon optimizer should not be used for:
    • Embedding layers
    • Final fully connected layer
    • Any 0-D or 1-D parameters
  • For 4D convolutional filters, flatten the last 3 dimensions before using Muon
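The routing rules above can be sketched as follows. This is an illustrative sketch, not nanochat's actual parameter-selection code; the toy model and all names are assumptions:

```python
import torch.nn as nn

# Toy model (illustrative): embedding -> hidden linear -> norm -> final linear
model = nn.Sequential(
    nn.Embedding(1000, 64),   # embedding: AdamW
    nn.Linear(64, 64),        # hidden 2D weight: Muon; 1D bias: AdamW
    nn.LayerNorm(64),         # 1D params: AdamW
    nn.Linear(64, 1000),      # final fully connected layer: AdamW
)
final = model[-1]

muon_params, adamw_params = [], []
for module in model:
    for p in module.parameters():
        use_muon = (
            p.ndim == 2
            and module is not final
            and not isinstance(module, nn.Embedding)
        )
        (muon_params if use_muon else adamw_params).append(p)

# For a 4D conv filter w, flatten the last 3 dims before handing it to Muon:
# w_2d = w.view(w.size(0), -1)
```

Only the hidden `Linear(64, 64)` weight lands in the Muon group here; the embedding table, the final layer, and every 1D tensor fall through to AdamW.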

Algorithm Details

Muon Step:
  1. Nesterov momentum
  2. Polar Express orthogonalization (ns_steps iterations, 5 by default)
  3. Variance reduction (NorMuon)
  4. Cautious weight decay + parameter update
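The orthogonalization in step 2 can be sketched with the classic Muon Newton-Schulz quintic iteration. Note this is an assumption-laden sketch: nanochat's "Polar Express" variant tunes different coefficients per iteration, and in the real optimizer this runs on the momentum buffer from step 1, not the raw gradient:

```python
import torch

def orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately snap G's singular values to 1 (sketch of step 2).

    Uses the classic Muon Newton-Schulz quintic coefficients; Polar
    Express uses different per-iteration coefficients.
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)           # normalize so singular values <= 1
    transposed = X.size(0) > X.size(1)
    if transposed:                       # iterate on the wide orientation
        X = X.mT
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    return X.mT if transposed else X
```

The iteration only drives singular values approximately to 1 (roughly the 0.7–1.2 range after 5 steps), which is sufficient in practice and much cheaper than an exact SVD.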
AdamW Step:
  1. Weight decay (decoupled)
  2. Momentum update
  3. Bias correction
  4. Parameter update
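The four AdamW steps above correspond to the standard decoupled-weight-decay update. A minimal single-tensor sketch (function and state names are assumptions, not nanochat's internals):

```python
import torch

def adamw_step(p, grad, state, lr=3e-4, betas=(0.9, 0.999),
               eps=1e-8, weight_decay=0.01):
    # 1. Decoupled weight decay: shrink the parameter directly,
    #    instead of adding wd * p to the gradient
    p.mul_(1 - lr * weight_decay)
    # 2. Momentum update: EMAs of the gradient and squared gradient
    state['step'] += 1
    state['m'].lerp_(grad, 1 - betas[0])
    state['v'].mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])
    # 3. Bias correction for the zero-initialized moments
    m_hat = state['m'] / (1 - betas[0] ** state['step'])
    v_hat = state['v'] / (1 - betas[1] ** state['step'])
    # 4. Parameter update
    p.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)
```

With matching hyperparameters, one call of this sketch agrees with one `torch.optim.AdamW` step; the fused kernel the module actually uses computes the same update in a single launch.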

DistMuonAdamW

Combined distributed optimizer for multi-GPU training.
class DistMuonAdamW(param_groups: list[dict])

Parameters

param_groups
list[dict]
required
List of parameter groups. Same format as MuonAdamW.

Additional requirement for Muon groups:
  • All params in a Muon group must have the same shape

Methods

step

@torch.no_grad()
def step()
Performs a single distributed optimization step with 3-phase async communication.

Design Goals

  • Overlap communication with computation (async ops)
  • Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
  • Batch small tensors into single comm ops where possible

Communication Pattern

Phase 1: Launch all async reduce ops
  • Kick off all reduce_scatter/all_reduce operations
  • Don’t wait; let them run in the background
Phase 2: Wait for reduces, compute updates, launch gathers
  • For each group: wait for its reduce, compute the update, launch gather
  • Earlier gathers run while later computes happen
Phase 3: Wait for gathers, copy back
  • Wait for all gathers to complete
  • Copy updated params back to original tensors (Muon only)

AdamW Communication (ZeRO-2)

  • Small params (<1024 elements): all_reduce gradients, update full param on each rank
  • Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update only that slice, then all_gather the updated slices
  • Optimizer state is sharded for large params
  • Requires param.shape[0] divisible by world_size
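The routing and slicing logic above can be sketched as pure index arithmetic (an illustrative sketch; the function name and return shape are assumptions):

```python
def plan_adamw_comm(numel: int, shape0: int, world_size: int,
                    rank: int, small_threshold: int = 1024):
    """Decide the collective op and which row slice this rank updates."""
    if numel < small_threshold:
        # all_reduce the grad; every rank updates the full parameter
        return ('all_reduce', 0, shape0)
    # reduce_scatter: each rank owns a contiguous 1/N row slice
    assert shape0 % world_size == 0, "param.shape[0] must divide evenly"
    rows = shape0 // world_size
    return ('reduce_scatter', rank * rows, (rank + 1) * rows)
```

For example, a (4096, 64) weight on 8 ranks gives each rank a 512-row slice (rank 3 updates rows 1536–2048), while a 64-element bias is all_reduced and updated in full everywhere.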

Muon Communication

  • Stack all K params into a single (K, *shape) tensor
  • Divide K params across N ranks: each rank owns ceil(K/N) params
  • reduce_scatter the stacked grads so each rank gets its chunk
  • Each rank computes Muon update only for params it owns
  • all_gather the updated params back to all ranks
  • Optimizer state is sharded by chunk
  • Zero-padding if K doesn’t divide evenly
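The chunking and zero-padding arithmetic above can be sketched as follows (an illustrative sketch; the function name is an assumption):

```python
import math

def muon_shard(K: int, world_size: int, rank: int):
    """Which of the K stacked params this rank owns, plus pad count."""
    chunk = math.ceil(K / world_size)    # params owned per rank
    pad = chunk * world_size - K         # zero-padded slots in the stack
    start = rank * chunk
    end = min(start + chunk, K)          # trailing ranks may own only padding
    return start, end, pad
```

For K=10 params on 4 ranks, each rank owns ceil(10/4)=3 slots, the stack is padded with 2 zero tensors, and rank 3 owns only one real param (index 9).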

Example

import torch
from nanochat.optim import DistMuonAdamW

# Note: DistMuonAdamW assumes torch.distributed is already initialized
# (e.g. when launched via torchrun).

# Separate parameters into AdamW and Muon groups
adamw_params = []  # embeddings, scalars, 1D params
muon_params = []   # 2D matrix params (same shape)

optimizer = DistMuonAdamW([
    {
        'params': adamw_params,
        'kind': 'adamw',
        'lr': 3e-4,
        'betas': (0.9, 0.999),
        'eps': 1e-8,
        'weight_decay': 0.01
    },
    {
        'params': muon_params,
        'kind': 'muon',
        'lr': 0.02,
        'momentum': 0.95,
        'ns_steps': 5,
        'beta2': 0.95,
        'weight_decay': 0.01
    }
])

# Training loop
for batch in dataloader:
    loss = model(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
