Overview
The `nanochat.optim` module provides combined optimizers that use Muon for 2D matrix parameters and AdamW for everything else. Two versions are available:
- `MuonAdamW` - single-GPU version
- `DistMuonAdamW` - distributed multi-GPU version with ZeRO-2 style sharding
MuonAdamW
Combined optimizer for single-GPU training.

Parameters
List of parameter groups. Each group is a dict containing:

Common fields:
- `params` (list): List of parameters
- `kind` (str): Either `'adamw'` or `'muon'`
For `'adamw'` groups:
- `lr` (float): Learning rate
- `betas` (tuple): Coefficients for computing running averages
- `eps` (float): Term added to the denominator for numerical stability
- `weight_decay` (float): Weight decay coefficient

For `'muon'` groups:
- `lr` (float): Learning rate
- `momentum` (float): Momentum coefficient
- `ns_steps` (int): Number of Newton-Schulz/Polar Express iterations
- `beta2` (float): Beta2 for the second moment
- `weight_decay` (float): Weight decay coefficient
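As a sketch of how such groups might be assembled (the layer choices, hyperparameter values, and the commented-out constructor call are illustrative assumptions, not nanochat's actual training code):

```python
import torch
import torch.nn as nn

# Toy model: embedding plus two linear layers (hypothetical, for illustration)
model = nn.Sequential(
    nn.Embedding(100, 64),
    nn.Linear(64, 128),
    nn.Linear(128, 64),
)

# Muon gets the 2D weight matrices of the linear layers;
# AdamW gets the embedding, biases, and everything else
muon_params, adamw_params = [], []
for module in model.modules():
    if isinstance(module, nn.Linear):
        muon_params.append(module.weight)
        if module.bias is not None:
            adamw_params.append(module.bias)
    elif isinstance(module, nn.Embedding):
        adamw_params.append(module.weight)

param_groups = [
    dict(params=adamw_params, kind="adamw", lr=3e-4,
         betas=(0.9, 0.95), eps=1e-8, weight_decay=0.0),
    dict(params=muon_params, kind="muon", lr=0.02,
         momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.0),
]
# optimizer = MuonAdamW(param_groups)  # constructor signature assumed
```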
Methods
step() - Performs a single optimization step, updating all parameter groups.
Notes
- AdamW: Uses fused AdamW optimizer step for non-matrix parameters (embeddings, scalars, biases)
- Muon: MomentUm Orthogonalized by Newton-schulz for 2D matrix parameters
- The Muon optimizer should not be used for:
- Embedding layers
- Final fully connected layer
- Any 0-D or 1-D parameters
- For 4D convolutional filters, flatten the last 3 dimensions before using Muon
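For the 4D case, the flattening is a simple view over a hypothetical conv layer (Muon then orthogonalizes the resulting `(out_channels, in_channels * kh * kw)` matrix):

```python
import torch.nn as nn

conv = nn.Conv2d(16, 32, kernel_size=3)  # weight shape: (32, 16, 3, 3)
w = conv.weight

# Flatten the last 3 dims so Muon sees a 2D matrix
w2d = w.view(w.size(0), -1)              # shape: (32, 16 * 3 * 3) = (32, 144)
```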
Algorithm Details
Muon step:
- Nesterov momentum
- Polar Express orthogonalization (5 iterations)
- Variance reduction (NorMuon)
- Cautious weight decay + parameter update

AdamW step:
- Weight decay (decoupled)
- Momentum update
- Bias correction
- Parameter update
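The orthogonalization at the heart of the Muon step can be sketched with the quintic Newton-Schulz iteration (coefficients from the original Muon write-up; nanochat's Polar Express variant uses a different coefficient schedule, so this is illustrative only). NumPy is used here for clarity:

```python
import numpy as np

def newton_schulz5(G, steps=5):
    """Approximately orthogonalize G: drive its singular values toward 1
    while preserving its singular vectors."""
    a, b, c = 3.4445, -4.7750, 2.0315     # quintic coefficients (Muon write-up)
    X = G / (np.linalg.norm(G) + 1e-7)    # spectral norm <= Frobenius norm <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:                        # work with the smaller Gram matrix
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

rng = np.random.default_rng(0)
G = rng.standard_normal((8, 16))
O = newton_schulz5(G)
# Singular values of O are pushed into a band around 1
```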
DistMuonAdamW
Combined distributed optimizer for multi-GPU training.

Parameters
List of parameter groups. Same format as `MuonAdamW`.

Additional requirement for Muon groups:
- All params in a Muon group must have the same shape
Methods
step() - Performs a single optimization step, overlapping communication with computation across ranks.
Design Goals
- Overlap communication with computation (async ops)
- Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
- Batch small tensors into single comm ops where possible
Communication Pattern
Phase 1: Launch all async reduce ops
- Kick off all reduce_scatter/all_reduce operations
- Don’t wait; let them run in the background

Phase 2: Process groups in order
- For each group: wait for its reduce, compute the update, launch its gather
- Earlier gathers run while later computes happen

Phase 3: Finalize
- Wait for all gathers to complete
- Copy updated params back to the original tensors (Muon only)
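A minimal single-process sketch of the launch-then-overlap idea, using only `all_reduce` on the gloo backend (reduce_scatter is unavailable on CPU); the update rule and names are placeholders, not nanochat's code:

```python
import tempfile
import torch
import torch.distributed as dist

# Single-process gloo group, initialized via a temp file (no network needed)
init_file = tempfile.NamedTemporaryFile(delete=False)
dist.init_process_group("gloo", init_method=f"file://{init_file.name}",
                        rank=0, world_size=1)

grads = [torch.ones(4), torch.ones(2)]

# Phase 1: kick off all async reduces without waiting
works = [dist.all_reduce(g, async_op=True) for g in grads]

# Phase 2: block only on the reduce needed next; later reduces keep running
updates = []
for g, w in zip(grads, works):
    w.wait()
    updates.append(-0.01 * g)  # placeholder for the real optimizer update

dist.destroy_process_group()
```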
AdamW Communication (ZeRO-2)
- Small params (<1024 elements): all_reduce gradients, update full param on each rank
- Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update only that slice, then all_gather the updated slices
- Optimizer state is sharded for large params
- Requires `param.shape[0]` divisible by `world_size`
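The large-param path reduces to simple slice arithmetic (a sketch of the sharding math only, with a made-up gradient standing in for the reduce_scatter output):

```python
import torch

world_size = 4
param = torch.randn(8, 16)          # shape[0] divisible by world_size
assert param.size(0) % world_size == 0
rows = param.size(0) // world_size  # each rank owns 2 of the 8 rows

# Each rank updates only its own 1/N slice; all_gather would then
# reassemble the full updated parameter on every rank.
for rank in range(world_size):
    grad_slice = torch.ones(rows, param.size(1))  # stand-in for the sharded grad
    param[rank * rows:(rank + 1) * rows] -= 0.01 * grad_slice
```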
Muon Communication
- Stack all K params into a single (K, *shape) tensor
- Divide K params across N ranks: each rank owns ceil(K/N) params
- reduce_scatter the stacked grads so each rank gets its chunk
- Each rank computes Muon update only for params it owns
- all_gather the updated params back to all ranks
- Optimizer state is sharded by chunk
- The stack is zero-padded if K doesn’t divide evenly by N
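The stacking and zero-padding arithmetic can be sketched directly (shapes and names are illustrative; no actual communication happens here):

```python
import math
import torch

K, world_size = 10, 4                  # 10 same-shape params across 4 ranks
shape = (64, 32)
params = [torch.randn(*shape) for _ in range(K)]

per_rank = math.ceil(K / world_size)   # each rank owns ceil(K/N) = 3 params
padded_K = per_rank * world_size       # 12; zero-pad because 10 % 4 != 0

stacked = torch.stack(params)          # (K, *shape)
stacked = torch.cat([stacked, stacked.new_zeros(padded_K - K, *shape)])

# chunk i plays the role of the reduce_scatter output owned by rank i
chunks = stacked.chunk(world_size)
```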