Overview
Pretrain a transformer language model using the nanochat framework. The training script implements:

- Automatic compute-optimal hyperparameter scaling (muP-style)
- Mixed Muon/AdamW optimization
- Learning rate scheduling with warmup/warmdown
- Distributed training with DDP
- Optional FP8 training (H100+ GPUs)
Quick Start
Single GPU:

The Depth Dial
The most important parameter is --depth, which controls model size. All hyperparameters scale automatically:
The --depth parameter determines:

- Model dimension: model_dim = depth × aspect_ratio (rounded to a multiple of head_dim)
- Training tokens: automatically computed to maintain the optimal data:param ratio
- Batch size: scales as D^0.383, following the Power Lines paper
- Learning rates: scale as √(B/B_ref) for batch size changes
- Weight decay: scales to maintain a constant T_epoch (see paper)
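The model-dimension rule above can be sketched as a small function. The default aspect_ratio of 64 and head_dim of 128 are assumptions for illustration; check the script's actual defaults.

```python
def model_dim_for_depth(depth: int, aspect_ratio: int = 64, head_dim: int = 128) -> int:
    """Round depth * aspect_ratio to the nearest multiple of head_dim.

    aspect_ratio and head_dim defaults here are assumed, not taken from the script.
    """
    target = depth * aspect_ratio
    n_heads = max(1, round(target / head_dim))
    return n_heads * head_dim
```

With these assumed defaults, d12 gives a 768-dim model and d20 gives 1280, both exact multiples of the head dimension.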
Core Parameters
Logging
Wandb run name. Set to “dummy” to disable wandb logging.
Model Architecture
Depth of the transformer (number of layers). This is the primary model size dial.
Model dimension = depth × aspect_ratio. The default gives clean muP scaling.
Target dimension per attention head. Model dim is rounded to nearest multiple.
Maximum context length (sequence length).
Sliding window attention pattern tiled across layers:

- L = full-context attention
- S = half-context sliding window

Use --window-pattern=L for better performance.

Training Horizon
Only one of these is used (in order of precedence):

- Explicit number of optimization steps. Overrides automatic calculation.
- Calculate num_iterations to reach target FLOPs. Used for scaling-laws analysis.
- Calculate num_iterations to maintain this data:param ratio. This is the default mode. Reference: Chinchilla uses 20:1; nanochat defaults to 10.5:1 (compute-optimal from scaling laws).
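The default data:param mode reduces to simple arithmetic; a minimal sketch (function name hypothetical), assuming total batch size is measured in tokens:

```python
def num_iterations_for_ratio(n_params: int, total_batch_size: int,
                             ratio: float = 10.5) -> int:
    """Steps needed so that tokens_seen / n_params ≈ ratio (the default mode)."""
    target_tokens = ratio * n_params
    return max(1, round(target_tokens / total_batch_size))
```

For example, a 100M-parameter model at the default 10.5:1 ratio and a 524,288-token batch trains for roughly 2,000 steps.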
Batch Size and Optimization
Per-device batch size. Reduce to 16, 8, or 4 if you run out of VRAM.
Total batch size in tokens across all devices.
-1 = auto-compute optimal batch size using the scaling law B_opt ∝ D^0.383.

Learning Rates
The optimizer uses different learning rates for different parameter groups:

Learning rate for transformer weight matrices (uses Muon optimizer).
Learning rate for input embedding (uses Adam).
Learning rate for output unembedding/lm_head (uses Adam).
Learning rate for scalar parameters (resid_lambdas, x0_lambdas).
Weight decay for Muon optimizer. Auto-scaled to maintain constant T_epoch.
Adam Hyperparameters
Adam beta1 (momentum) for embedding/unembedding parameters.
Adam beta2 (RMSprop) for embedding/unembedding parameters.
Learning Rate Schedule
Fraction of iterations for linear LR warmup. 0.0 = no warmup.
Fraction of iterations for linear LR warmdown. 0.5 = last half of training.
Final LR as fraction of initial LR. 0.0 = decay to zero.
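The three schedule knobs above combine into a single multiplier on the base learning rate. A minimal sketch, assuming linear warmup and linear warmdown as described (function name hypothetical):

```python
def lr_multiplier(step: int, num_iterations: int,
                  warmup_frac: float = 0.0,
                  warmdown_frac: float = 0.5,
                  final_frac: float = 0.0) -> float:
    """Multiplier on the base LR at a given step (sketch, not the script's code)."""
    warmup_steps = int(warmup_frac * num_iterations)
    warmdown_steps = int(warmdown_frac * num_iterations)
    if warmup_steps and step < warmup_steps:
        # linear ramp from ~0 up to 1.0
        return (step + 1) / warmup_steps
    warmdown_start = num_iterations - warmdown_steps
    if warmdown_steps and step >= warmdown_start:
        # linear decay from 1.0 down to final_frac
        progress = (step - warmdown_start) / warmdown_steps
        return 1.0 - progress * (1.0 - final_frac)
    return 1.0
```

With the defaults shown (no warmup, warmdown over the last half, decay to zero), the multiplier holds at 1.0 for the first half of training and then decays linearly.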
Evaluation
Evaluate validation bits-per-byte every N steps. -1 to disable.
Number of tokens for validation evaluation (default: 40 × 524288 ≈ 21M).
Evaluate CORE metric (downstream tasks) every N steps. -1 to disable.
Max examples per task for CORE metric evaluation.
Generate text samples every N steps. -1 to disable.
Save checkpoint every N steps. -1 = only save at the end.
Advanced
Device type: cuda, cpu, or mps. Empty string = autodetect.
Enable FP8 training (requires H100+ GPU and torchao). Significantly faster.
FP8 scaling recipe:
tensorwise (faster, recommended) or rowwise (more accurate).

Resume training from this checkpoint step. -1 = don’t resume.
Override model tag for checkpoint directory. Default: d{depth} (e.g., “d12”).

Scaling Laws and Auto-Tuning
The training script implements automatic hyperparameter scaling based on empirical scaling laws:

Reference Model (d12)
All scaling is anchored to depth=12 as the reference:

- D_ref: Optimal training tokens for d12
- B_ref: Optimal batch size = 524,288 tokens
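Taken together, the rules in the subsections below can be sketched as functions of these reference values. The exponents and ratios come from the statements in this document; D_ref is left as a parameter since its numeric value is not stated here, and the function names are hypothetical.

```python
B_REF = 524_288  # reference batch size in tokens (d12)

def optimal_tokens(n_params: float, ratio: float = 10.5) -> float:
    # Default mode: hold the data:param ratio constant (10.5:1)
    return ratio * n_params

def optimal_batch_size(d_tokens: float, d_ref: float) -> float:
    # B_opt ∝ D^0.383 (Power Lines paper), anchored at (d_ref, B_REF)
    return B_REF * (d_tokens / d_ref) ** 0.383

def scaled_lr(base_lr: float, batch_size: float) -> float:
    # LR scales as sqrt(B / B_ref) when batch size deviates from reference
    return base_lr * (batch_size / B_REF) ** 0.5
```

For instance, quadrupling the batch size relative to B_ref doubles every learning rate, and at the reference token count the batch size stays at 524,288.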
Optimal Training Tokens
Optimal Batch Size
Follows the Power Lines paper (arXiv:2505.13738):

Learning Rate Scaling
When batch size changes from reference:

Weight Decay Scaling
Follows the T_epoch framework (arXiv:2405.13698):

Example Workflows
Train a 100M parameter model (d12)
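A plausible invocation, assuming the nanochat layout where the trainer lives at scripts/base_train.py; since d12 is the reference model, the defaults handle everything else:

```shell
# Hypothetical command: d12 is the reference depth, all other
# hyperparameters are derived automatically
python -m scripts.base_train -- --depth=12
```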
Train a 300M parameter model (d20)
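A plausible multi-GPU invocation (script path and torchrun usage assumed); only the depth changes, and the auto-tuning rescales tokens, batch size, and learning rates:

```shell
# Hypothetical command: larger depth, launched with DDP across 8 GPUs
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20
```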
Override batch size and training length
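A plausible invocation pinning the batch size and horizon instead of auto-computing them. The --num-iterations flag name is an assumption; this document describes an explicit-steps option without naming it, so check the script's actual flag:

```shell
# Hypothetical command: --num-iterations is an assumed flag name
python -m scripts.base_train -- --depth=12 --total-batch-size=262144 --num-iterations=5000
```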
Train with FP8 (H100 GPU)
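A plausible invocation enabling FP8 (script path assumed); per the Advanced section this requires an H100+ GPU and torchao:

```shell
# Hypothetical command: FP8 training on H100-class hardware
python -m scripts.base_train -- --depth=20 --fp8
```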
Resume training from checkpoint
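A plausible invocation for resuming. The --resume-from flag name is an assumption for the "resume from this checkpoint step" option described in the Advanced section, which this document does not name:

```shell
# Hypothetical command: --resume-from is an assumed flag name
python -m scripts.base_train -- --depth=12 --resume-from=4000
```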
Output
Checkpoints are saved to $NANOCHAT_BASE_DIR/base_checkpoints/{model_tag}/:
- step_{N}_model.pt - Model weights
- step_{N}_optimizer.pt - Optimizer state
- step_{N}_meta.json - Metadata (config, validation loss, etc.)
Monitoring
Key metrics logged to console and wandb:

- train/loss - Training loss (cross-entropy)
- val/bpb - Validation bits per byte (better than loss, invariant to vocab size)
- train/mfu - Model FLOPs Utilization (% of peak GPU performance)
- train/tok_per_sec - Throughput in tokens/second
- core_metric - Performance on downstream tasks (ARC, HellaSwag, MMLU, etc.)
Performance Tips
- VRAM management: Reduce --device-batch-size if OOM
- Flash Attention 3: Use Hopper GPUs (H100) for best efficiency
- Sliding windows: Use --window-pattern=L if not using FA3
- FP8 training: Enable with --fp8 on H100+ for 1.5-2× speedup
- Gradient accumulation: Automatically computed to hit --total-batch-size
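The gradient-accumulation bookkeeping in the last tip is straightforward; a sketch (function name hypothetical), assuming the total batch size is measured in tokens across all ranks:

```python
def grad_accum_steps(total_batch_size: int, device_batch_size: int,
                     seq_len: int, world_size: int) -> int:
    """Micro-batches each rank accumulates before an optimizer step (sketch)."""
    tokens_per_micro_step = device_batch_size * seq_len * world_size
    assert total_batch_size % tokens_per_micro_step == 0, \
        "total batch size must divide evenly into per-micro-step tokens"
    return total_batch_size // tokens_per_micro_step
```

This is why shrinking --device-batch-size to dodge OOM does not change the effective batch: the accumulation count grows to compensate, at the cost of more micro-steps per update.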