
Overview

The training module provides a complete implementation for training GPT models with support for:
  • Single GPU and distributed data parallel (DDP) training
  • Mixed precision training (float16/bfloat16)
  • Gradient accumulation
  • Learning rate scheduling with warmup and cosine decay
  • Checkpointing and resumption
  • WandB integration for experiment tracking

Training modes

You can run the training script in several configurations. For a single-GPU debug run:
python train.py --batch_size=32 --compile=False
For multi-GPU DDP training on a single node, launch with torchrun:
torchrun --standalone --nproc_per_node=8 train.py
If your cluster does not have InfiniBand, prepend NCCL_IB_DISABLE=1 to the commands.

Key functions

get_batch

get_batch(split)
Loads a batch of training or validation data using memory-mapped files.
  • split (str, required): Data split to load: 'train' or 'val'
Returns: Tuple of (x, y)
  • x: Input token sequences of shape (batch_size, block_size)
  • y: Target token sequences of shape (batch_size, block_size), shifted by one position
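The essential behavior is the shift-by-one relationship between x and y. A minimal pure-Python sketch of that relationship (the real function samples random offsets from an np.memmap and returns torch tensors; `data`, `batch_size`, and `block_size` here are illustrative stand-ins):

```python
import random

def get_batch_sketch(data, batch_size, block_size):
    """Sample (x, y) windows from a flat token sequence.

    y is x shifted left by one position, so y[t] is the prediction
    target after the model has seen x[0..t].
    """
    ix = [random.randrange(len(data) - block_size) for _ in range(batch_size)]
    x = [data[i : i + block_size] for i in ix]
    y = [data[i + 1 : i + 1 + block_size] for i in ix]
    return x, y

tokens = list(range(100))  # toy "dataset" of 100 token ids
x, y = get_batch_sketch(tokens, batch_size=4, block_size=8)
```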

estimate_loss

@torch.no_grad()
estimate_loss()
Computes smoothed loss estimates by averaging over many batches for both the training and validation splits.
Returns: Dictionary with keys 'train' and 'val', each containing the mean loss
  • eval_iters (int, default 200): Number of iterations to average over (configured globally)
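Structurally, the function just averages eval_iters per-batch losses for each split (the real implementation additionally wraps this in @torch.no_grad() and toggles model.eval()/model.train() around the measurement). A sketch with the model and batch loader stubbed out as hypothetical callables:

```python
def estimate_loss_sketch(loss_fn, batches, eval_iters=200, splits=('train', 'val')):
    """Average loss_fn over eval_iters batches for each split."""
    out = {}
    for split in splits:
        total = 0.0
        for _ in range(eval_iters):
            total += loss_fn(batches(split))
        out[split] = total / eval_iters
    return out

# Stubbed example: every batch produces a loss of 1.0.
losses = estimate_loss_sketch(lambda batch: 1.0, lambda split: None, eval_iters=4)
```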

get_lr

get_lr(it)
Computes the learning rate for a given iteration using cosine decay with linear warmup.
  • it (int, required): Current iteration number
Returns: float - Learning rate for this iteration
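The schedule is a linear warmup to learning_rate, then a cosine ramp down to min_lr, clamped at min_lr past lr_decay_iters. A self-contained sketch using the defaults from the configuration section below:

```python
import math

learning_rate = 6e-4    # maximum LR (default)
min_lr = 6e-5           # ~learning_rate / 10 (default)
warmup_iters = 2000
lr_decay_iters = 600000

def get_lr(it):
    # 1) Linear warmup for warmup_iters steps.
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) Past the decay horizon, hold at min_lr.
    if it > lr_decay_iters:
        return min_lr
    # 3) In between, cosine-decay from learning_rate down to min_lr.
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes 1 -> 0
    return min_lr + coeff * (learning_rate - min_lr)
```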

Training loop structure

The main training loop performs the following steps:

1. Learning rate scheduling

lr = get_lr(iter_num) if decay_lr else learning_rate
for param_group in optimizer.param_groups:
    param_group['lr'] = lr

2. Periodic evaluation

if iter_num % eval_interval == 0 and master_process:
    losses = estimate_loss()
    print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    # Save checkpoint if validation loss improved or always_save_checkpoint=True

3. Forward and backward pass

With gradient accumulation to simulate larger batch sizes:
for micro_step in range(gradient_accumulation_steps):
    if ddp:
        # Only sync gradients on the last micro step
        model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
    with ctx:  # Automatic mixed precision context
        logits, loss = model(X, Y)
        loss = loss / gradient_accumulation_steps
    # Async prefetch next batch
    X, Y = get_batch('train')
    # Backward pass with gradient scaling for fp16
    scaler.scale(loss).backward()
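Dividing each micro-step loss by gradient_accumulation_steps before backward() makes the accumulated gradient equal the gradient of the mean loss over the whole effective batch. A pure-Python check on a one-parameter model (all names here are illustrative):

```python
# Model: pred = w * x, squared-error loss; d(loss)/dw = 2 * (w*x - y) * x
w = 0.5
data = [(1.0, 2.0), (2.0, 1.0), (3.0, 4.0), (4.0, 0.5)]

# One big batch: gradient of the mean loss over all 4 examples.
full_grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)

# Gradient accumulation: 2 micro-batches of 2, each micro-loss divided
# by gradient_accumulation_steps (= 2) before "backward".
steps = 2
acc_grad = 0.0
for micro in (data[:2], data[2:]):
    micro_grad = sum(2 * (w * x - y) * x for x, y in micro) / len(micro)
    acc_grad += micro_grad / steps  # what loss.backward() accumulates
```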

4. Gradient clipping and optimizer step

if grad_clip != 0.0:
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
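clip_grad_norm_ rescales the entire gradient vector so its global L2 norm does not exceed grad_clip. A minimal pure-Python version of the same rule (illustrative; the torch implementation operates over all parameter tensors and adds a small epsilon to the denominator):

```python
import math

def clip_grad_norm(grads, max_norm):
    """Scale grads in place so their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads[:] = [g * scale for g in grads]
    return total_norm  # torch also returns the pre-clip norm

grads = [3.0, 4.0]  # global norm 5.0
pre_clip_norm = clip_grad_norm(grads, max_norm=1.0)
```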

5. Logging

if iter_num % log_interval == 0 and master_process:
    lossf = loss.item() * gradient_accumulation_steps
    if local_iter_num >= 5:
        mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
        running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
    print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")

Configuration parameters

I/O settings

  • out_dir (str, default 'out'): Directory for saving checkpoints
  • eval_interval (int, default 2000): How often to evaluate on the val set and save checkpoints
  • log_interval (int, default 1): How often to log training metrics
  • eval_iters (int, default 200): Number of iterations for loss estimation
  • eval_only (bool, default False): If True, exit after the first evaluation (useful for testing)
  • always_save_checkpoint (bool, default True): If True, save a checkpoint after each eval even if val loss didn't improve
  • init_from (str, default 'scratch'): Initialization mode: 'scratch', 'resume', or a GPT-2 variant ('gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl')

Data settings

  • dataset (str, default 'openwebtext'): Name of the dataset (must have a corresponding data/<dataset>/ directory)
  • gradient_accumulation_steps (int, default 40): Accumulate gradients over this many steps to simulate larger batches
  • batch_size (int, default 12): Micro-batch size (per GPU if using DDP)
  • block_size (int, default 1024): Context length for training sequences

Model architecture

  • n_layer (int, default 12): Number of transformer layers
  • n_head (int, default 12): Number of attention heads
  • n_embd (int, default 768): Embedding dimension
  • dropout (float, default 0.0): Dropout rate (0.0 for pretraining, 0.1+ for finetuning)
  • bias (bool, default False): Use bias in Linear and LayerNorm layers

Optimizer settings

  • learning_rate (float, default 6e-4): Maximum learning rate
  • max_iters (int, default 600000): Total number of training iterations
  • weight_decay (float, default 1e-1): Weight decay coefficient
  • beta1 (float, default 0.9): AdamW beta1 parameter
  • beta2 (float, default 0.95): AdamW beta2 parameter
  • grad_clip (float, default 1.0): Gradient clipping threshold (0.0 to disable)

Learning rate decay

  • decay_lr (bool, default True): Enable learning rate decay
  • warmup_iters (int, default 2000): Number of warmup iterations
  • lr_decay_iters (int, default 600000): Iterations for learning rate decay (should be ~= max_iters)
  • min_lr (float, default 6e-5): Minimum learning rate (should be ~= learning_rate/10)

System settings

  • device (str, default 'cuda'): Device to train on: 'cpu', 'cuda', 'cuda:0', 'cuda:1', 'mps', etc.
  • dtype (str, default 'bfloat16' if supported, else 'float16'): Data type for training: 'float32', 'bfloat16', or 'float16'
  • compile (bool, default True): Use PyTorch 2.0 compilation for faster training

DDP settings

  • backend (str, default 'nccl'): DDP backend: 'nccl' (recommended for CUDA) or 'gloo'

Checkpointing

Checkpoints are saved to {out_dir}/ckpt.pt and contain:
{
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'model_args': model_args,
    'iter_num': iter_num,
    'best_val_loss': best_val_loss,
    'config': config,
}
To resume training from a checkpoint, set init_from='resume' and ensure the checkpoint exists in out_dir.
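Resuming reverses the save: load the dictionary, rebuild the model from model_args, then restore both state dicts and the loop counters. The sketch below shows the round-trip shape; the real script uses torch.save/torch.load, whereas this stand-in uses a plain dict with stdlib pickle (and placeholder state values) so it stays self-contained:

```python
import os
import pickle
import tempfile

# Stand-in for the checkpoint dictionary the training loop saves.
checkpoint = {
    'model': {'wte.weight': [0.1, 0.2]},   # placeholder for model.state_dict()
    'optimizer': {'step': 100},            # placeholder for optimizer.state_dict()
    'model_args': {'n_layer': 12, 'n_head': 12, 'n_embd': 768},
    'iter_num': 100,
    'best_val_loss': 3.21,
    'config': {'learning_rate': 6e-4},
}

out_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
with open(ckpt_path, 'wb') as f:   # real script: torch.save(checkpoint, ckpt_path)
    pickle.dump(checkpoint, f)

with open(ckpt_path, 'rb') as f:   # real script: torch.load(ckpt_path, map_location=device)
    ckpt = pickle.load(f)

# On resume, the script rebuilds the model from ckpt['model_args'], then:
#   model.load_state_dict(ckpt['model'])
#   optimizer.load_state_dict(ckpt['optimizer'])
iter_num = ckpt['iter_num']
best_val_loss = ckpt['best_val_loss']
```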

WandB integration

  • wandb_log (bool, default False): Enable Weights & Biases logging
  • wandb_project (str, default 'owt'): WandB project name
  • wandb_run_name (str, default 'gpt2'): WandB run name
When enabled, the following metrics are logged:
  • Training loss
  • Validation loss
  • Learning rate
  • Model FLOPs Utilization (MFU)

Performance tips

Use gradient accumulation to simulate larger batch sizes without running out of memory. Effective batch size = batch_size * gradient_accumulation_steps * num_gpus.
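With the defaults from the configuration section (batch_size=12, gradient_accumulation_steps=40, block_size=1024) on an assumed 8-GPU node, the arithmetic works out as:

```python
batch_size = 12                    # micro-batch per GPU (default)
gradient_accumulation_steps = 40   # default
num_gpus = 8                       # assumption for this example
block_size = 1024                  # tokens per sequence (default)

# Sequences contributing to each optimizer step, and tokens seen per step.
effective_batch = batch_size * gradient_accumulation_steps * num_gpus
tokens_per_iter = effective_batch * block_size
```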
Enable compilation with compile=True to use PyTorch 2.0’s optimizations for faster training (requires PyTorch >= 2.0).
Use bfloat16 if your GPU supports it (requires Ampere or newer). It provides better numerical stability than float16 without requiring gradient scaling.
Allow TF32 matmuls (the script enables them by default) for roughly a 20% speedup on Ampere GPUs with negligible accuracy impact.
