Overview

nanochat supports distributed training across multiple GPUs on a single node using PyTorch’s torchrun launcher. The codebase automatically detects distributed settings and handles data distribution, gradient synchronization, and checkpoint sharding.

Quick Start

Single GPU

Run training directly with Python:
python -m scripts.base_train --depth=12

Multi-GPU (8 GPUs)

Use torchrun to launch distributed training:
OMP_NUM_THREADS=1 torchrun --nproc_per_node=8 -m scripts.base_train --depth=26
The OMP_NUM_THREADS=1 environment variable prevents PyTorch from spawning too many CPU threads per process.

How It Works

Automatic DDP Setup

From nanochat/common.py, the compute_init() function automatically detects and configures distributed training:
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0  # Only rank 0 logs, saves checkpoints
Environment variables (set automatically by torchrun):
  • RANK: Global rank of this process (0 to world_size-1)
  • LOCAL_RANK: Local rank on this node
  • WORLD_SIZE: Total number of processes
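These variables can be read directly from the environment. A minimal, self-contained sketch of the detection logic (illustrative only, not the actual `compute_init()` implementation):

```python
import os

# Minimal sketch of torchrun detection: these variables are only
# present when the process was launched via torchrun.
def detect_distributed():
    ddp = int(os.environ.get("RANK", -1)) != -1  # torchrun sets RANK
    if ddp:
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        rank, local_rank, world_size = 0, 0, 1  # plain `python` launch
    return ddp, rank, local_rank, world_size
```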

Data Distribution

Each GPU processes a different subset of the data. The dataloader automatically shards data based on rank:
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(
    tokenizer, 
    args.device_batch_size,  # Per-GPU batch size
    args.max_seq_len, 
    split="train", 
    device=device,
    resume_state_dict=dataloader_resume_state_dict
)
From scripts/base_train.py:390-396:
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
Total batch size = device_batch_size × max_seq_len × num_gpus × grad_accum_steps
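Plugging illustrative numbers into this relationship (the speedrun-style configuration shown later in this page) makes it concrete:

```python
# Worked example of the batch-size relationship above.
device_batch_size = 32
max_seq_len = 2048
num_gpus = 8
total_batch_size = 1_048_576  # ~1M-token target batch

tokens_per_fwdbwd = device_batch_size * max_seq_len       # 65,536
world_tokens_per_fwdbwd = tokens_per_fwdbwd * num_gpus    # 524,288
assert total_batch_size % world_tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print(grad_accum_steps)  # 2
```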

Gradient Synchronization

PyTorch DDP automatically synchronizes gradients across GPUs:
  1. Forward pass: Each GPU computes loss on its local batch
  2. Backward pass: Gradients are computed locally
  3. Allreduce: DDP automatically averages gradients across all GPUs
  4. Optimizer step: Each GPU applies the same weight update
The model stays synchronized without explicit communication code.
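The allreduce in step 3 can be illustrated without any GPUs. This standalone sketch (with made-up per-rank gradient values) averages gradients exactly the way DDP does for each parameter:

```python
# Conceptual sketch of DDP's allreduce-average (plain Python, no GPUs):
# every rank contributes its local gradient, and all ranks end up
# holding the same averaged gradient.
local_grads = {  # hypothetical gradients for one 2-element parameter
    0: [1.0, 2.0],  # rank 0's local gradient
    1: [3.0, 6.0],  # rank 1's local gradient
}
world_size = len(local_grads)

# Allreduce-average: sum across ranks, divide by world size.
averaged = [
    sum(g[i] for g in local_grads.values()) / world_size
    for i in range(2)
]
print(averaged)  # [2.0, 4.0]

# Every rank now applies the same optimizer step with `averaged`,
# which is what keeps the replicas in lockstep.
```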

Configuration

Batch Size Scaling

When using multiple GPUs, you have three options:

1. Keep total batch size constant (faster)

# 8 GPUs: reduce per-GPU batch from 32 to 4
torchrun --nproc_per_node=8 -m scripts.base_train \
    --depth=12 \
    --device-batch-size=4 \
    --total-batch-size=524288
Each step processes the same number of tokens, but the work is split across 8 GPUs, so training finishes ~8x faster with the same convergence.

2. Keep per-GPU batch constant (more compute)

# 8 GPUs: total batch grows 8x
torchrun --nproc_per_node=8 -m scripts.base_train \
    --depth=12 \
    --device-batch-size=32
    # total_batch_size auto-computed
nanochat will automatically scale learning rates and weight decay based on batch size.

3. Use gradient accumulation

If your GPU memory is limited:
# Accumulate gradients over 4 micro-steps
torchrun --nproc_per_node=8 -m scripts.base_train \
    --depth=26 \
    --device-batch-size=8 \
    --total-batch-size=1048576
From scripts/base_train.py:390-396, gradient accumulation is calculated automatically:
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")

Learning Rate Scaling

nanochat automatically scales learning rates with batch size following √(B/B_ref) scaling from AdamW theory:
batch_lr_scale = (total_batch_size / B_REF) ** 0.5
embedding_lr = args.embedding_lr * batch_lr_scale
matrix_lr = args.matrix_lr * batch_lr_scale
From scripts/base_train.py:283-289.
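A quick numeric illustration of the square-root rule (B_REF here is a placeholder reference batch size, not nanochat's actual constant):

```python
# Illustration of sqrt batch-size LR scaling.
B_REF = 524_288               # hypothetical reference batch (tokens)
total_batch_size = 2_097_152  # 4x the reference

batch_lr_scale = (total_batch_size / B_REF) ** 0.5
print(batch_lr_scale)  # 2.0: a 4x larger batch -> 2x learning rate
```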

Weight Decay Scaling

Weight decay is scaled to maintain constant “effective epochs” (T_epoch framework):
weight_decay_scaled = args.weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens)
From scripts/base_train.py:297-299.
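A numeric illustration of this rule (B_REF and D_REF are placeholder reference constants here, not nanochat's actual values):

```python
import math

# Illustration of the weight-decay scaling rule above.
B_REF = 524_288        # hypothetical reference batch size (tokens)
D_REF = 1.0e10         # hypothetical reference token horizon
weight_decay = 0.1
total_batch_size = 2_097_152  # 4x the reference -> sqrt term = 2.0
target_tokens = 2.0e10        # 2x the reference -> ratio term = 0.5

weight_decay_scaled = (
    weight_decay * math.sqrt(total_batch_size / B_REF) * (D_REF / target_tokens)
)
print(weight_decay_scaled)  # 0.1 * 2.0 * 0.5 = 0.1
```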

Logging and Checkpointing

Only rank 0 (the master process) performs I/O operations:

Logging

master_process = ddp_rank == 0
if master_process:
    print(f"Step {step} | Loss: {loss:.4f}")
    wandb_run.log({"loss": loss})
The print0() helper function only prints on rank 0.
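A minimal sketch of such a helper (the real one lives in nanochat/common.py; this illustrative version keys off the RANK environment variable):

```python
import os

# Illustrative print0-style helper: only rank 0 produces output,
# so 8 processes don't print 8 copies of every log line.
def print0(*args, **kwargs):
    if int(os.environ.get("RANK", 0)) == 0:
        print(*args, **kwargs)
```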

Checkpointing

The model and metadata are saved by rank 0 only, while the optimizer state is sharded across all ranks:
save_checkpoint(
    checkpoint_dir,
    step,
    orig_model.state_dict(),  # Rank 0 only
    optimizer.state_dict(),   # All ranks (sharded)
    meta_data,                # Rank 0 only
    rank=ddp_rank,
)
From scripts/base_train.py:461-482. Each rank saves its own optimizer shard:
checkpoint_dir/
├── model_000500.pt          # Full model (rank 0 only)
├── meta_000500.json         # Metadata (rank 0 only)
├── optim_000500_rank0.pt    # Optimizer shard for rank 0
├── optim_000500_rank1.pt    # Optimizer shard for rank 1
├── ...
└── optim_000500_rank7.pt    # Optimizer shard for rank 7

Resuming Training

Resume from a checkpoint with --resume-from-step:
torchrun --nproc_per_node=8 -m scripts.base_train \
    --depth=26 \
    --resume-from-step=5000
The script automatically:
  1. Loads model weights on rank 0, broadcasts to all ranks
  2. Loads optimizer state from per-rank shards
  3. Loads dataloader state to resume from correct position
  4. Restores loop state (loss EMA, training time, etc.)
From scripts/base_train.py:154-158:
if resuming:
    model_data, optimizer_data, meta_data = load_checkpoint(
        checkpoint_dir, args.resume_from_step, device, 
        load_optimizer=True, rank=ddp_rank
    )
    model.load_state_dict(model_data, strict=True, assign=True)
    optimizer.load_state_dict(optimizer_data)

Performance Optimization

Communication Optimization

To minimize communication overhead:
  1. Use NCCL backend (automatic for CUDA)
  2. Set OMP_NUM_THREADS=1 to avoid CPU contention
  3. Maximize per-GPU batch size before using gradient accumulation
  4. Use high-bandwidth interconnect (NVLink, InfiniBand)

Measuring Efficiency

Monitor these metrics in your training logs:
flops_per_sec = num_flops_per_token * total_batch_size / dt
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
print0(f"MFU: {mfu:.2f}%")  # Model FLOPS Utilization
From scripts/base_train.py:522-524. Good MFU on H100:
  • 30-40% without FP8 (typical)
  • 40-50% with FP8 (current speedrun achieves ~45%)
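A back-of-envelope check of these numbers, using the common ~6 × num_params FLOPs-per-token estimate and an assumed ~989 TFLOPS dense BF16 peak per H100 (all values here are illustrative, not measured):

```python
# Back-of-envelope MFU estimate (hypothetical values throughout).
num_params = 1.0e9            # assumed 1B-parameter model
num_flops_per_token = 6 * num_params  # rough transformer estimate
total_batch_size = 1_048_576  # tokens per optimizer step
dt = 2.0                      # assumed seconds per step
ddp_world_size = 8
gpu_peak_flops = 989e12       # assumed H100 dense BF16 peak

flops_per_sec = num_flops_per_token * total_batch_size / dt
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
print(f"MFU: {mfu:.2f}%")  # prints: MFU: 39.76%
```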

Single GPU Fallback

The same code runs on a single GPU without any changes:
# Single GPU - no torchrun needed
python -m scripts.base_train --depth=12
Gradient accumulation automatically increases to match the target total batch size:
# 1 GPU: accumulate 8x more gradients to match 8-GPU throughput
grad_accum_steps = total_batch_size // (tokens_per_fwdbwd * 1)
Training takes ~8x longer but follows the same optimization trajectory, so results match up to floating-point nondeterminism.

Example: 8XH100 Speedrun

The GPT-2 speedrun configuration from runs/speedrun.sh:
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train \
    --depth=26 \
    --run="speedrun" \
    --model-tag="d26" \
    --fp8 \
    --fp8-recipe=tensorwise \
    --device-batch-size=32 \
    --max-seq-len=2048 \
    --eval-every=500 \
    --core-metric-every=2000 \
    --sample-every=2000 \
    --save-every=-1
Key settings:
  • 8 GPUs × 32 batch × 2048 tokens = 524,288 tokens per step (before grad accum)
  • Total batch auto-computed to ~1M tokens (via scaling laws)
  • Gradient accumulation: ~2 steps
  • Completes in ~2.9 hours on 8XH100

Troubleshooting

NCCL Timeout

If you see NCCL timeouts:
export NCCL_TIMEOUT=3600  # Increase timeout to 1 hour
export NCCL_DEBUG=INFO    # Enable debug logging

Out of Memory

Reduce --device-batch-size:
# Try 32 -> 16 -> 8 -> 4 -> 2 -> 1
torchrun --nproc_per_node=8 -m scripts.base_train \
    --depth=26 \
    --device-batch-size=16  # Reduced from 32

Uneven GPU Utilization

If one GPU is much slower:
  • Check for hardware issues (nvidia-smi)
  • Ensure data is distributed evenly (automatic in nanochat)
  • Verify all GPUs are the same model

Further Reading

  • PyTorch DDP Tutorial
  • PyTorch torchrun
  • nanochat/common.py - DDP initialization code
  • nanochat/dataloader.py - Distributed data loading
  • nanochat/optim.py - Distributed optimizer (ZeRO-1 style sharding)
