Overview
nanochat supports distributed training across multiple GPUs on a single node using PyTorch's torchrun launcher. The codebase automatically detects distributed settings and handles data distribution, gradient synchronization, and checkpoint sharding.
Quick Start
Single GPU
Run training directly with Python:
Multi-GPU (8 GPUs)
Use torchrun to launch distributed training:
The OMP_NUM_THREADS=1 environment variable prevents PyTorch from spawning too many CPU threads per process.
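The launch commands above might look like the following (the module path scripts.base_train is assumed from the repo layout; any training flags are elided):

```shell
# Single GPU: run the training script directly with Python
python -m scripts.base_train

# 8 GPUs on one node: torchrun spawns 8 processes and sets the
# RANK / LOCAL_RANK / WORLD_SIZE environment variables for each;
# OMP_NUM_THREADS=1 limits each process to one CPU thread pool
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train
```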
How It Works
Automatic DDP Setup
From nanochat/common.py, the compute_init() function automatically detects and configures distributed training:
It reads the environment variables set by torchrun:
- RANK: Global rank of this process (0 to world_size-1)
- LOCAL_RANK: Local rank on this node
- WORLD_SIZE: Total number of processes
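A minimal sketch of how this detection can work (not the actual compute_init() body; the function name and single-process defaults here are assumptions):

```python
import os

def detect_ddp():
    """Read the environment variables torchrun sets; fall back to a
    single-process configuration when they are absent."""
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size > 1  # distributed only when more than one process exists
    return ddp, rank, local_rank, world_size
```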
Data Distribution
Each GPU processes a different subset of the data. The dataloader automatically shards data based on rank (scripts/base_train.py:390-396). The total number of tokens processed per optimizer step is:
device_batch_size × max_seq_len × num_gpus × grad_accum_steps
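Rank-based sharding can be illustrated with a strided split (a sketch of the idea, not the actual dataloader logic):

```python
def shard_indices(num_docs, rank, world_size):
    """Each rank takes every world_size-th document starting at its own
    rank, so the shards are disjoint and together cover the dataset."""
    return list(range(rank, num_docs, world_size))
```

With 4 ranks and 10 documents, rank 0 reads documents 0, 4, 8, rank 1 reads 1, 5, 9, and so on; no document is read twice and none is skipped.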
Gradient Synchronization
PyTorch DDP automatically synchronizes gradients across GPUs:
- Forward pass: Each GPU computes loss on its local batch
- Backward pass: Gradients are computed locally
- Allreduce: DDP automatically averages gradients across all GPUs
- Optimizer step: Each GPU applies the same weight update
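The allreduce-average step can be simulated without any GPUs (a sketch of the arithmetic DDP performs, not of its implementation):

```python
def allreduce_mean(per_rank_grads):
    """Average the gradient vectors held by each rank elementwise, as
    DDP's allreduce does, so every rank applies the same update."""
    world_size = len(per_rank_grads)
    n = len(per_rank_grads[0])
    return [sum(g[i] for g in per_rank_grads) / world_size for i in range(n)]
```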
Configuration
Batch Size Scaling
When using multiple GPUs, you have three options:
1. Keep total batch size constant (faster)
2. Keep per-GPU batch constant (more compute)
3. Use gradient accumulation
If your GPU memory is limited, reduce the per-device batch size. Per scripts/base_train.py:390-396, gradient accumulation is then calculated automatically to preserve the total batch size:
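The automatic calculation amounts to the following (a sketch; the variable names are assumptions):

```python
def grad_accum_steps(total_batch_tokens, device_batch_size, max_seq_len, num_gpus):
    """Tokens per forward pass across all GPUs, then how many passes
    are needed to reach the target total batch size in tokens."""
    tokens_per_fwd = device_batch_size * max_seq_len * num_gpus
    assert total_batch_tokens % tokens_per_fwd == 0, "total batch must divide evenly"
    return total_batch_tokens // tokens_per_fwd
```

Halving the device batch size doubles the accumulation steps, so the total tokens per optimizer step stay constant.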
Learning Rate Scaling
nanochat automatically scales the learning rate with batch size, following √(B/B_ref) scaling from AdamW theory (scripts/base_train.py:283-289).
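The √(B/B_ref) rule can be sketched as follows (the function and argument names are illustrative, not nanochat's):

```python
import math

def scale_lr(base_lr, batch_tokens, ref_batch_tokens):
    """Scale the AdamW learning rate by sqrt(B / B_ref):
    quadrupling the batch size doubles the learning rate."""
    return base_lr * math.sqrt(batch_tokens / ref_batch_tokens)
```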
Weight Decay Scaling
Weight decay is scaled to maintain a constant number of “effective epochs” (the T_epoch framework); see scripts/base_train.py:297-299.
Logging and Checkpointing
Only rank 0 (the master process) performs I/O operations.
Logging
The print0() helper function only prints on rank 0.
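A typical print0() helper looks like this (a sketch; nanochat's version may obtain the rank differently):

```python
import os

def print0(*args, **kwargs):
    """Print only on the master process (rank 0), so log lines are not
    duplicated once per GPU."""
    if int(os.environ.get("RANK", 0)) == 0:
        print(*args, **kwargs)
```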
Checkpointing
The model and metadata are saved by rank 0 only, but optimizer state is sharded across all ranks (scripts/base_train.py:461-482). Each rank saves its own optimizer shard:
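One way to lay out per-rank shard files is to key them by rank (the naming scheme below is hypothetical, not nanochat's actual checkpoint layout):

```python
def optim_shard_path(checkpoint_dir, step, rank):
    """Hypothetical naming: one optimizer-state file per rank per step,
    so each process can save and reload its own shard independently."""
    return f"{checkpoint_dir}/optim_{step:06d}_rank{rank}.pt"
```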
Resuming Training
Resume from a checkpoint with --resume-from-step:
- Loads model weights on rank 0, broadcasts to all ranks
- Loads optimizer state from per-rank shards
- Loads dataloader state to resume from correct position
- Restores loop state (loss EMA, training time, etc.)
See scripts/base_train.py:154-158:
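A resume invocation might look like this (the flag comes from the text above; the step value and module path are illustrative):

```shell
# Resume the 8-GPU run from the checkpoint saved at step 1000
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 \
  -m scripts.base_train -- --resume-from-step 1000
```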
Performance Optimization
Communication Optimization
To minimize communication overhead:
- Use the NCCL backend (automatic for CUDA)
- Set OMP_NUM_THREADS=1 to avoid CPU contention
- Maximize per-GPU batch size before using gradient accumulation
- Use a high-bandwidth interconnect (NVLink, InfiniBand)
Measuring Efficiency
Monitor these metrics in your training logs (scripts/base_train.py:522-524). Good MFU on an H100:
- 30-40% without FP8 (typical)
- 40-50% with FP8 (current speedrun achieves ~45%)
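MFU (model FLOPs utilization) compares achieved throughput against hardware peak. A sketch using the standard ≈6·N FLOPs-per-token approximation for transformer training (the default peak assumes an H100's dense BF16 throughput of roughly 989 TFLOPS; adjust for your hardware and precision):

```python
def mfu(params, tokens_per_sec, peak_flops=989e12):
    """Fraction of peak FLOPs actually used: training costs about
    6 * params FLOPs per token (forward + backward)."""
    achieved_flops = 6 * params * tokens_per_sec
    return achieved_flops / peak_flops
```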
Single GPU Fallback
The same code runs on a single GPU without any changes.
Example: 8XH100 Speedrun
The GPT-2 speedrun configuration from runs/speedrun.sh:
- 8 GPUs × 32 batch × 2048 tokens = 524,288 tokens per step (before grad accum)
- Total batch auto-computed to ~1M tokens (via scaling laws)
- Gradient accumulation: ~2 steps
- Completes in ~2.9 hours on 8XH100
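The batch arithmetic above can be checked directly:

```python
# Tokens per forward pass across all 8 GPUs: 8 GPUs x 32 batch x 2048 tokens
tokens_per_fwd = 8 * 32 * 2048   # 524,288 tokens
# With 2 gradient-accumulation steps, the total batch is ~1M tokens
total_batch = tokens_per_fwd * 2  # 1,048,576 tokens
```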
Troubleshooting
NCCL Timeout
If you see NCCL timeouts, enable NCCL logging (NCCL_DEBUG=INFO) to diagnose, and verify that every rank reaches the same collectives.
Out of Memory
Reduce --device-batch-size:
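For example (the flag is from the text above; the value is illustrative):

```shell
# Halve the per-GPU batch; gradient accumulation compensates automatically
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --device-batch-size 16
```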
Uneven GPU Utilization
If one GPU is much slower:
- Check for hardware issues (nvidia-smi)
- Ensure data is distributed evenly (automatic in nanochat)
- Verify all GPUs are the same model
Further Reading
- PyTorch DDP Tutorial
- PyTorch torchrun
- nanochat/common.py - DDP initialization code
- nanochat/dataloader.py - Distributed data loading
- nanochat/optim.py - Distributed optimizer (ZeRO-1 style sharding)