Overview
The training module provides a complete implementation for training GPT models with support for:
Single GPU and distributed data parallel (DDP) training
Mixed precision training (float16/bfloat16)
Gradient accumulation
Learning rate scheduling with warmup and cosine decay
Checkpointing and resumption
WandB integration for experiment tracking
Training modes
You can run the training script in multiple configurations:
Single GPU
DDP on 4 GPUs (single node)
DDP on multiple nodes
python train.py --batch_size=32 --compile=False
If your cluster does not have InfiniBand, prepend NCCL_IB_DISABLE=1 to the launch commands.
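The DDP configurations above are typically launched with torchrun; the GPU counts, node counts, and master address below are illustrative placeholders, not values from the script.

```shell
# Single node, 4 GPUs (adjust nproc_per_node to your GPU count)
torchrun --standalone --nproc_per_node=4 train.py

# Two nodes, 8 GPUs each: run this on the master node first
# (123.456.123.456 and 1234 are placeholder address/port values)
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 \
    --master_addr=123.456.123.456 --master_port=1234 train.py
# ...and this on the worker node:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 \
    --master_addr=123.456.123.456 --master_port=1234 train.py
```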
Key functions
get_batch
Loads a batch of training or validation data using memory-mapped files.
split: data split to load from, 'train' or 'val'
Returns: Tuple of (x, y)
x: Input token sequences of shape (batch_size, block_size)
y: Target token sequences of shape (batch_size, block_size), shifted by one position
Implementation details
The function:
Creates a new numpy memmap for each batch to avoid memory leaks
Randomly samples starting positions from the dataset
Extracts sequences of length block_size
Creates input (x) and target (y) sequences where y is x shifted by one token
Pins memory and transfers to GPU asynchronously if using CUDA
Data is expected to be stored as binary files:
data/{dataset}/train.bin
data/{dataset}/val.bin
Each file contains uint16 token indices.
estimate_loss
@torch.no_grad()
estimate_loss()
Computes accurate loss estimates over multiple batches for both training and validation splits.
Returns: Dictionary with keys 'train' and 'val', each containing the mean loss
eval_iters: number of iterations to average over (configured globally)
Why multiple batches?
Evaluating on multiple batches provides a more stable and accurate loss estimate compared to a single batch, which can be noisy. The function:
Sets model to eval mode
Runs inference on eval_iters batches for each split
Averages the losses
Returns model to train mode
This is called periodically during training (every eval_interval steps) to monitor progress.
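The steps above can be sketched as below. This version passes model, get_batch, eval_iters, and the mixed-precision context in explicitly for self-containment; the actual script reads them from module-level configuration.

```python
import contextlib
import torch

@torch.no_grad()
def estimate_loss(model, get_batch, eval_iters=200, ctx=contextlib.nullcontext()):
    out = {}
    model.eval()  # disable dropout etc. for a stable estimate
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:  # same mixed-precision context as the training loop
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()  # restore training mode
    return out
```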
get_lr
Computes the learning rate for a given iteration using cosine decay with linear warmup.
Returns: float - Learning rate for this iteration
Learning rate schedule
The schedule has three phases:
Linear warmup (iterations 0 to warmup_iters):
lr = learning_rate * (it + 1) / (warmup_iters + 1)
Cosine decay (iterations warmup_iters to lr_decay_iters):
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
coeff = 0.5 * (1.0 + cos(π * decay_ratio))
lr = min_lr + coeff * (learning_rate - min_lr)
Constant minimum (iterations > lr_decay_iters):
lr = min_lr
This follows the approach from the Chinchilla paper.
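The three phases translate directly into code; the default values shown are illustrative, not the script's configured ones.

```python
import math

def get_lr(it, learning_rate=6e-4, min_lr=6e-5,
           warmup_iters=2000, lr_decay_iters=600000):
    # 1) linear warmup
    if it < warmup_iters:
        return learning_rate * (it + 1) / (warmup_iters + 1)
    # 3) constant minimum after decay ends
    if it > lr_decay_iters:
        return min_lr
    # 2) cosine decay in between
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (learning_rate - min_lr)
```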
Training loop structure
The main training loop performs the following steps:
1. Learning rate scheduling
lr = get_lr(iter_num) if decay_lr else learning_rate
for param_group in optimizer.param_groups:
    param_group['lr'] = lr
2. Periodic evaluation
if iter_num % eval_interval == 0 and master_process:
    losses = estimate_loss()
    print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    # Save checkpoint if validation loss improved or always_save_checkpoint=True
3. Forward and backward pass
With gradient accumulation to simulate larger batch sizes:
for micro_step in range(gradient_accumulation_steps):
    if ddp:
        # Only sync gradients on the last micro step
        model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
    with ctx:  # automatic mixed precision context
        logits, loss = model(X, Y)
        loss = loss / gradient_accumulation_steps
    # Async prefetch next batch while the GPU is busy
    X, Y = get_batch('train')
    # Backward pass with gradient scaling for fp16
    scaler.scale(loss).backward()
4. Gradient clipping and optimizer step
if grad_clip != 0.0:
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
5. Logging
if iter_num % log_interval == 0 and master_process:
    lossf = loss.item() * gradient_accumulation_steps
    if local_iter_num >= 5:
        mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
        running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
    print(f"iter {iter_num}: loss {lossf:.4f}, time {dt * 1000:.2f}ms, mfu {running_mfu * 100:.2f}%")
Configuration parameters
I/O settings
Directory for saving checkpoints
How often to evaluate on val set and save checkpoints
How often to log training metrics
Number of iterations for loss estimation
If True, exit after first evaluation (useful for testing)
If True, save checkpoint after each eval even if val loss didn’t improve
Initialization mode: 'scratch', 'resume', or a GPT-2 variant ('gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl')
Data settings
dataset
str
default: "'openwebtext'"
Name of the dataset (must have a corresponding data/{dataset}/ directory)
gradient_accumulation_steps
Accumulate gradients over this many steps to simulate larger batches
Micro-batch size (per GPU if using DDP)
Context length for training sequences
Model architecture
Number of transformer layers
Number of attention heads
Dropout rate (0.0 for pretraining, 0.1+ for finetuning)
Use bias in Linear and LayerNorm layers
Optimizer settings
Total number of training iterations
Gradient clipping threshold (0.0 to disable)
Learning rate decay
Enable learning rate decay
Number of warmup iterations
Iterations for learning rate decay (should be ~= max_iters)
Minimum learning rate (should be ~= learning_rate/10)
System settings
Device to train on: 'cpu', 'cuda', 'cuda:0', 'cuda:1', 'mps', etc.
dtype
str
default: "'bfloat16' or 'float16'"
Data type for training: 'float32', 'bfloat16', or 'float16'. Automatically selects bfloat16 if supported.
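For illustration, the dtype selection and the mixed-precision context (the ctx used in the training loop above) might be set up along these lines; the variable names here are assumptions, not necessarily the script's exact code.

```python
import contextlib
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
# Prefer bfloat16 where the hardware supports it, else fall back to float16
dtype = ("bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
         else "float16")
ptdtype = {"float32": torch.float32,
           "bfloat16": torch.bfloat16,
           "float16": torch.float16}[dtype]
# Autocast only makes sense on GPU; on CPU use a no-op context
ctx = (contextlib.nullcontext() if device_type == "cpu"
       else torch.autocast(device_type=device_type, dtype=ptdtype))
```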
Use PyTorch 2.0 compilation for faster training
DDP settings
DDP backend: 'nccl' (recommended for CUDA) or 'gloo'
Checkpointing
Checkpoints are saved to {out_dir}/ckpt.pt and contain:
{
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'model_args': model_args,
    'iter_num': iter_num,
    'best_val_loss': best_val_loss,
    'config': config,
}
To resume training from a checkpoint, set init_from='resume' and ensure the checkpoint exists in out_dir.
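A minimal round-trip sketch of this checkpoint format, using a stand-in linear model and hypothetical model_args (the real checkpoint holds GPT weights and the full training config):

```python
import os
import tempfile
import torch

out_dir = tempfile.mkdtemp()  # stands in for the configured out_dir
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Save in the format described above
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'model_args': {'n_embd': 4},  # hypothetical model args
    'iter_num': 100,
    'best_val_loss': 3.21,
    'config': {},
}
torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))

# Resuming (init_from='resume'): rebuild the model from model_args, then load
ckpt = torch.load(os.path.join(out_dir, 'ckpt.pt'), map_location='cpu')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
iter_num = ckpt['iter_num']
best_val_loss = ckpt['best_val_loss']
```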
WandB integration
Enable Weights & Biases logging
When enabled, the following metrics are logged:
Training loss
Validation loss
Learning rate
Model FLOPs Utilization (MFU)
Performance tips
Use gradient accumulation to simulate larger batch sizes without running out of memory. The effective batch size is batch_size * gradient_accumulation_steps * num_gpus.
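For example, with illustrative settings (not the script's defaults):

```python
batch_size = 12                    # micro-batch per GPU
gradient_accumulation_steps = 40   # accumulate before each optimizer step
num_gpus = 8                       # DDP world size
effective_batch = batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch)  # 3840
```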
Enable compilation with compile=True to use PyTorch 2.0’s optimizations for faster training (requires PyTorch >= 2.0).
Use bfloat16 if your GPU supports it (requires Ampere or newer). It provides better numerical stability than float16 without requiring gradient scaling.
Allow TF32 (enabled by default) for ~20% speedup on Ampere GPUs without loss of accuracy.