Overview

Pretrain a transformer language model using the nanochat framework. The training script implements:
  • Automatic compute-optimal hyperparameter scaling (muP-style)
  • Mixed Muon/AdamW optimization
  • Learning rate scheduling with warmup/warmdown
  • Distributed training with DDP
  • Optional FP8 training (H100+ GPUs)

Quick Start

Single GPU:
python -m scripts.base_train
Multi-GPU (8 GPUs):
torchrun --nproc_per_node=8 -m scripts.base_train
CPU/MacBook (tiny model for testing):
python -m scripts.base_train \
  --depth=4 \
  --max-seq-len=512 \
  --device-batch-size=1 \
  --total-batch-size=512 \
  --num-iterations=20

The Depth Dial

The most important parameter is --depth, which controls model size. All hyperparameters scale automatically:
python -m scripts.base_train --depth=12  # ~100M params
python -m scripts.base_train --depth=20  # ~300M params
python -m scripts.base_train --depth=28  # ~600M params
The --depth parameter determines:
  • Model dimension: model_dim = depth × aspect_ratio, rounded to a multiple of head_dim
  • Training tokens: Automatically computed to maintain the optimal data:param ratio
  • Batch size: Scales as D^0.383, following the Power Lines paper (arXiv:2505.13738)
  • Learning rates: Scale as √(B/B_ref) when the batch size changes
  • Weight decay: Scales to maintain a constant T_epoch (see arXiv:2405.13698)
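As a rough sketch, the model-dimension rule can be written out like this, using the documented defaults for aspect_ratio and head_dim; derive_model_dim is an illustrative helper, not the script's actual function:

```python
aspect_ratio = 64   # --aspect-ratio default
head_dim = 128      # --head-dim default

def derive_model_dim(depth: int) -> int:
    """model_dim = depth * aspect_ratio, rounded to a multiple of head_dim."""
    raw = depth * aspect_ratio
    return round(raw / head_dim) * head_dim

# With the defaults, the depths above land on clean multiples:
# d12 -> 768, d20 -> 1280, d28 -> 1792
```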

Core Parameters

Logging

--run
str
default:"dummy"
Wandb run name. Set to "dummy" to disable wandb logging.

Model Architecture

--depth
int
default:"20"
Depth of the transformer (number of layers). This is the primary model size dial.
--aspect-ratio
int
default:"64"
Model dimension = depth × aspect_ratio. The default gives clean muP scaling.
--head-dim
int
default:"128"
Target dimension per attention head. Model dim is rounded to nearest multiple.
--max-seq-len
int
default:"2048"
Maximum context length (sequence length).
--window-pattern
str
default:"SSSL"
Sliding window attention pattern tiled across layers:
  • L = full context attention
  • S = half-context sliding window
Example: "SSSL" means layers use [half, half, half, full, half, half, half, full, …].

Note: Without Flash Attention 3, use --window-pattern=L for better performance.
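To make the tiling concrete, here is a hypothetical helper (not part of nanochat) that expands a pattern string across layers:

```python
def tile_window_pattern(pattern: str, num_layers: int) -> list[str]:
    """Repeat a pattern like "SSSL" across all layers.
    S = half-context sliding window, L = full-context attention."""
    return [pattern[i % len(pattern)] for i in range(num_layers)]

# An 8-layer model with --window-pattern=SSSL:
print(tile_window_pattern("SSSL", 8))
# ['S', 'S', 'S', 'L', 'S', 'S', 'S', 'L']
```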

Training Horizon

Only one of these is used (in order of precedence):
--num-iterations
int
default:"-1"
Explicit number of optimization steps. Overrides automatic calculation.
--target-flops
float
default:"-1"
Calculate num_iterations to reach target FLOPs. Used for scaling laws analysis.
--target-param-data-ratio
float
default:"10.5"
Calculate num_iterations to maintain this data:param ratio. This is the default mode.

Reference: Chinchilla uses 20:1; nanochat defaults to 10.5:1 (compute-optimal from scaling laws).
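The precedence between the three horizon flags can be sketched as follows (function and argument names are illustrative; scaling_params is the parameter count described under Scaling Laws and Auto-Tuning):

```python
def resolve_num_iterations(total_batch_size: int, *, num_iterations: int = -1,
                           target_flops: float = -1.0, flops_per_token: float = 0.0,
                           target_param_data_ratio: float = 10.5,
                           scaling_params: int = 0) -> int:
    """Resolve the training horizon using the documented precedence."""
    if num_iterations > 0:                 # 1. explicit step count wins
        return num_iterations
    if target_flops > 0:                   # 2. next, a FLOPs budget
        total_tokens = target_flops / flops_per_token
    else:                                  # 3. default: data:param ratio
        total_tokens = target_param_data_ratio * scaling_params
    return round(total_tokens / total_batch_size)
```

For example, a 100M-parameter model at the default 10.5:1 ratio and a 524,288-token batch trains for about 2,000 steps.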

Batch Size and Optimization

--device-batch-size
int
default:"32"
Per-device batch size. Reduce to 16, 8, or 4 if you run out of VRAM.
--total-batch-size
int
default:"-1"
Total batch size in tokens across all devices. -1 = auto-compute the optimal batch size using the scaling law B_opt ∝ D^0.383.

Learning Rates

The optimizer uses different learning rates for different parameter groups:
--matrix-lr
float
default:"0.02"
Learning rate for transformer weight matrices (uses Muon optimizer).
--embedding-lr
float
default:"0.3"
Learning rate for input embedding (uses Adam).
--unembedding-lr
float
default:"0.004"
Learning rate for output unembedding/lm_head (uses Adam).
--scalar-lr
float
default:"0.5"
Learning rate for scalar parameters (resid_lambdas, x0_lambdas).
--weight-decay
float
default:"0.2"
Weight decay for Muon optimizer. Auto-scaled to maintain constant T_epoch.

Adam Hyperparameters

--adam-beta1
float
default:"0.8"
Adam beta1 (momentum) for embedding/unembedding parameters.
--adam-beta2
float
default:"0.95"
Adam beta2 (RMSprop) for embedding/unembedding parameters.

Learning Rate Schedule

--warmup-ratio
float
default:"0.0"
Fraction of iterations for linear LR warmup. 0.0 = no warmup.
--warmdown-ratio
float
default:"0.5"
Fraction of iterations for linear LR warmdown. 0.5 = last half of training.
--final-lr-frac
float
default:"0.0"
Final LR as fraction of initial LR. 0.0 = decay to zero.
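Putting the three knobs together, the schedule is trapezoidal: linear warmup, a flat plateau, then linear warmdown toward final_lr_frac. A minimal sketch, not the script's exact code:

```python
def lr_multiplier(step: int, num_iterations: int, warmup_ratio: float = 0.0,
                  warmdown_ratio: float = 0.5, final_lr_frac: float = 0.0) -> float:
    """Multiplier applied to each group's base learning rate at a given step."""
    warmup = int(warmup_ratio * num_iterations)
    warmdown = int(warmdown_ratio * num_iterations)
    if step < warmup:                                  # linear warmup
        return (step + 1) / warmup
    if step > num_iterations - warmdown:               # linear warmdown
        frac = (num_iterations - step) / warmdown      # goes 1 -> 0
        return final_lr_frac + (1.0 - final_lr_frac) * frac
    return 1.0                                         # flat plateau
```

With the defaults, the multiplier stays at 1.0 for the first half of training and decays linearly to zero over the second half.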

Evaluation

--eval-every
int
default:"250"
Evaluate validation bits-per-byte every N steps. -1 to disable.
--eval-tokens
int
default:"20971520"
Number of tokens for validation evaluation (default: 40 × 524288 ≈ 21M).
--core-metric-every
int
default:"2000"
Evaluate CORE metric (downstream tasks) every N steps. -1 to disable.
--core-metric-max-per-task
int
default:"500"
Max examples per task for CORE metric evaluation.
--sample-every
int
default:"2000"
Generate text samples every N steps. -1 to disable.
--save-every
int
default:"-1"
Save checkpoint every N steps. -1 = only save at the end.

Advanced

--device-type
str
default:""
Device type: cuda, cpu, or mps. Empty string = autodetect.
--fp8
bool
default:"false"
Enable FP8 training (requires H100+ GPU and torchao). Significantly faster.
--fp8-recipe
str
default:"tensorwise"
FP8 scaling recipe: tensorwise (faster, recommended) or rowwise (more accurate).
--resume-from-step
int
default:"-1"
Resume training from this checkpoint step. -1 = don't resume.
--model-tag
str
default:"None"
Override model tag for checkpoint directory. Default: d{depth} (e.g., "d12").

Scaling Laws and Auto-Tuning

The training script implements automatic hyperparameter scaling based on empirical scaling laws:

Reference Model (d12)

All scaling is anchored to depth=12 as the reference:
  • D_ref: Optimal training tokens for d12
  • B_ref: Optimal batch size = 524,288 tokens

Optimal Training Tokens

optimal_tokens = target_param_data_ratio × scaling_params
# scaling_params = transformer_matrices + lm_head parameters
Default ratio of 10.5:1 is compute-optimal (derived from scaling laws experiments).

Optimal Batch Size

Follows the Power Lines paper (arXiv:2505.13738):
B_opt = B_ref × (D / D_ref)^0.383
Rounded to nearest power of 2 for efficiency.
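In code, that rule might look like the following (an assumed helper, with D measured in training tokens):

```python
import math

def optimal_batch_size(tokens: float, ref_tokens: float,
                       ref_batch: int = 524_288) -> int:
    """B_opt = B_ref * (D / D_ref)^0.383, rounded to the nearest power of two."""
    b_opt = ref_batch * (tokens / ref_tokens) ** 0.383
    return 2 ** round(math.log2(b_opt))

# At the d12 reference horizon this returns B_ref itself (524,288 tokens);
# an 8x larger horizon roughly doubles the batch to 1,048,576 tokens.
```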

Learning Rate Scaling

When batch size changes from reference:
lr_scale = √(B / B_ref)
Applied to all learning rates (embedding, unembedding, matrix, scalar).

Weight Decay Scaling

Follows T_epoch framework (arXiv:2405.13698):
λ = λ_ref × √(B / B_ref) × (D_ref / D)
Maintains constant effective regularization strength across model sizes.
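The two rescalings (learning rate and weight decay) can be combined in one illustrative helper; names are assumed, and the base values are the documented defaults:

```python
import math

def rescale_hypers(batch: int, ref_batch: int, tokens: float, ref_tokens: float,
                   base_lr: float, base_wd: float) -> tuple[float, float]:
    """lr scales as sqrt(B/B_ref); weight decay adds a (D_ref/D) factor."""
    lr_scale = math.sqrt(batch / ref_batch)
    return base_lr * lr_scale, base_wd * lr_scale * (ref_tokens / tokens)

# A 4x larger batch doubles the learning rate; a 2x longer horizon then
# halves the weight decay's (D_ref/D) factor, so wd is unchanged overall.
```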

Example Workflows

Train a 100M parameter model (d12)

torchrun --nproc_per_node=8 -m scripts.base_train \
  --depth=12 \
  --run=my_d12_run
Everything scales automatically to compute-optimal settings.

Train a 300M parameter model (d20)

torchrun --nproc_per_node=8 -m scripts.base_train \
  --depth=20 \
  --run=my_d20_run

Override batch size and training length

torchrun --nproc_per_node=8 -m scripts.base_train \
  --depth=16 \
  --total-batch-size=262144 \
  --num-iterations=10000 \
  --run=custom_run

Train with FP8 (H100 GPU)

torchrun --nproc_per_node=8 -m scripts.base_train \
  --depth=20 \
  --fp8 \
  --fp8-recipe=tensorwise \
  --run=fp8_run

Resume training from checkpoint

torchrun --nproc_per_node=8 -m scripts.base_train \
  --depth=20 \
  --resume-from-step=5000 \
  --model-tag=d20 \
  --run=resumed_run

Output

Checkpoints are saved to $NANOCHAT_BASE_DIR/base_checkpoints/{model_tag}/:
  • step_{N}_model.pt - Model weights
  • step_{N}_optimizer.pt - Optimizer state
  • step_{N}_meta.json - Metadata (config, validation loss, etc.)

Monitoring

Key metrics logged to console and wandb:
  • train/loss - Training loss (cross-entropy)
  • val/bpb - Validation bits per byte (better than loss, invariant to vocab size)
  • train/mfu - Model FLOPs Utilization (% of peak GPU performance)
  • train/tok_per_sec - Throughput in tokens/second
  • core_metric - Performance on downstream tasks (ARC, HellaSwag, MMLU, etc.)

Performance Tips

  1. VRAM management: Reduce --device-batch-size if OOM
  2. Flash Attention 3: Use Hopper GPUs (H100) for best efficiency
  3. Sliding windows: Use --window-pattern=L if not using FA3
  4. FP8 training: Enable with --fp8 on H100+ for 1.5-2× speedup
  5. Gradient accumulation: Automatically computed to hit --total-batch-size
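For tip 5, the accumulation count falls out of the batch-size bookkeeping. A sketch, assuming total batch size is given in tokens as elsewhere on this page:

```python
def grad_accum_steps(total_batch_tokens: int, device_batch_size: int,
                     seq_len: int, world_size: int) -> int:
    """How many micro-batches each rank accumulates before an optimizer step."""
    tokens_per_micro_step = device_batch_size * seq_len * world_size
    assert total_batch_tokens % tokens_per_micro_step == 0
    return total_batch_tokens // tokens_per_micro_step

# Defaults on 8 GPUs: 32 * 2048 * 8 = 524,288 tokens, so no accumulation;
# a single GPU with the same settings accumulates 8 micro-batches per step.
```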
