TrainingConfig

Hyperparameters and bookkeeping for a single training or finetuning run. The fields cover heuristics common to training GPT-style models at scale (Kaplan et al., 2020), such as gradient accumulation, mixed precision, and logging cadence.

Required parameters

run_name
str
required
Name of the training run for identification and logging.
dataset_name
str
required
Name or path of the dataset to use for training.
tokenizer_name
str
required
Name or path of the tokenizer (e.g., “gpt2”, “meta-llama/Llama-2-7b-hf”).
output_dir
Path
required
Directory to save checkpoints, logs, and artifacts. Created automatically if it doesn’t exist.
batch_size
int
required
Global batch size across all gradient accumulation steps. Must be positive.
micro_batch_size
int
required
Batch size per forward/backward pass. Must be positive and not exceed batch_size.
gradient_accumulation_steps
int
required
Number of micro-batches to accumulate before updating weights. Must be positive.
learning_rate
float
required
Peak learning rate for the optimizer. Must be positive.
max_steps
int
required
Maximum number of training steps (optimizer updates). Must be positive.
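The three batch parameters are linked: each optimizer update processes the global batch_size as gradient_accumulation_steps micro-batches of micro_batch_size. A framework-free sketch (not this library's training loop) showing why averaging micro-batch gradients reproduces the full-batch gradient for a mean loss:

```python
def grad_mse(w, xs, ys):
    """Mean gradient of 0.5*(w*x - y)^2 over a batch."""
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.3
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]    # batch_size = 8
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]

micro_batch_size = 2
accum_steps = len(xs) // micro_batch_size         # 4 micro-batches per update

# Accumulate each micro-batch's mean gradient, then average before updating.
accum = 0.0
for i in range(accum_steps):
    lo, hi = i * micro_batch_size, (i + 1) * micro_batch_size
    accum += grad_mse(w, xs[lo:hi], ys[lo:hi])
accum /= accum_steps

# With equal-sized micro-batches, the mean of means equals the full-batch mean.
full = grad_mse(w, xs, ys)
assert abs(accum - full) < 1e-12
```

This is why the examples below always satisfy batch_size = micro_batch_size × gradient_accumulation_steps.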

Optional parameters

warmup_steps
int
default:"0"
Number of warmup steps for learning rate schedule. Must be non-negative.
weight_decay
float
default:"0.0"
Weight decay coefficient for AdamW optimizer. Must be non-negative.
max_grad_norm
float
default:"1.0"
Maximum gradient norm for gradient clipping. Must be positive.
eval_every
int
default:"500"
Evaluate on validation set every N steps. Must be non-negative.
save_every
int
default:"500"
Save checkpoint every N steps. Must be non-negative.
log_every
int
default:"50"
Log training metrics every N steps. Must be non-negative.
seed
Optional[int]
default:"42"
Random seed for reproducibility. Set to None for non-deterministic training.
mixed_precision
Literal['bf16', 'fp16', 'fp32']
default:"bf16"
Mixed precision training dtype. “bf16” recommended for modern GPUs.
gradient_checkpointing
bool
default:"True"
Trade compute for memory by recomputing activations during backward pass.
compile_model
bool
default:"True"
Use torch.compile for significant speedup on modern GPUs (PyTorch 2.0+).
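warmup_steps and max_steps only parameterize a schedule; the schedule shape itself is not fixed by this config. A common pairing, sketched here as an assumption rather than the library's actual implementation, is linear warmup to the peak learning_rate followed by cosine decay:

```python
import math

def lr_at(step, peak_lr, warmup_steps, max_steps):
    """Linear warmup to peak_lr, then cosine decay toward 0 at max_steps.
    Illustrative only -- the real schedule is chosen by the training loop."""
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * min(progress, 1.0)))

# Using the pretraining values below: learning_rate=3e-4, warmup_steps=2000
assert lr_at(0, 3e-4, 2000, 100_000) == 0.0          # start of warmup
assert abs(lr_at(2000, 3e-4, 2000, 100_000) - 3e-4) < 1e-18  # peak
assert abs(lr_at(100_000, 3e-4, 2000, 100_000)) < 1e-12      # decayed to ~0
```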

Example

from pathlib import Path
from modern_llm.config import TrainingConfig

# Pretraining config
pretrain_config = TrainingConfig(
    run_name="gpt2-pretrain",
    dataset_name="wikitext-103-v1",
    tokenizer_name="gpt2",
    output_dir=Path("experiments/pretrain"),
    batch_size=256,
    micro_batch_size=8,
    gradient_accumulation_steps=32,  # 256 / 8 = 32
    learning_rate=3e-4,
    max_steps=100000,
    warmup_steps=2000,
    weight_decay=0.1,
    max_grad_norm=1.0,
    eval_every=1000,
    save_every=5000,
    log_every=100,
    seed=42,
    mixed_precision="bf16",
    gradient_checkpointing=True,
    compile_model=True,
)

# Fine-tuning config (lower learning rate, smaller batch)
sft_config = TrainingConfig(
    run_name="gpt2-sft",
    dataset_name="tatsu-lab/alpaca",
    tokenizer_name="gpt2",
    output_dir=Path("experiments/sft"),
    batch_size=32,
    micro_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    max_steps=5000,
    warmup_steps=100,
    weight_decay=0.01,
    eval_every=500,
    save_every=1000,
    log_every=50,
)

# DPO config (even smaller learning rate)
dpo_config = TrainingConfig(
    run_name="gpt2-dpo",
    dataset_name="Anthropic/hh-rlhf",
    tokenizer_name="gpt2",
    output_dir=Path("experiments/dpo"),
    batch_size=16,
    micro_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=5e-6,
    max_steps=2000,
    warmup_steps=50,
    weight_decay=0.0,
    eval_every=500,
    save_every=500,
)

Validation rules

  • batch_size, micro_batch_size, gradient_accumulation_steps, and max_steps must be positive
  • micro_batch_size cannot exceed batch_size
  • warmup_steps, eval_every, save_every, and log_every must be non-negative
  • learning_rate and max_grad_norm must be positive
  • weight_decay must be non-negative
  • mixed_precision must be one of “bf16”, “fp16”, or “fp32”
  • output_dir is automatically created if it doesn’t exist
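The rules above can be expressed compactly with a plain dataclass. The sketch below is an illustrative re-implementation covering a subset of the fields, not the library's actual validator:

```python
import tempfile
from dataclasses import dataclass
from pathlib import Path

@dataclass
class TrainingConfigSketch:
    """Illustrative subset of TrainingConfig's validation rules."""
    output_dir: Path
    batch_size: int
    micro_batch_size: int
    learning_rate: float
    warmup_steps: int = 0
    weight_decay: float = 0.0
    mixed_precision: str = "bf16"

    def __post_init__(self):
        if self.batch_size <= 0 or self.micro_batch_size <= 0:
            raise ValueError("batch sizes must be positive")
        if self.micro_batch_size > self.batch_size:
            raise ValueError("micro_batch_size cannot exceed batch_size")
        if self.learning_rate <= 0:
            raise ValueError("learning_rate must be positive")
        if self.warmup_steps < 0 or self.weight_decay < 0:
            raise ValueError("warmup_steps and weight_decay must be non-negative")
        if self.mixed_precision not in ("bf16", "fp16", "fp32"):
            raise ValueError("mixed_precision must be bf16, fp16, or fp32")
        # output_dir is created automatically if it doesn't exist
        self.output_dir.mkdir(parents=True, exist_ok=True)

cfg = TrainingConfigSketch(Path(tempfile.mkdtemp()) / "run", 32, 4, 1e-5)
```

Passing an invalid combination (for example micro_batch_size=8 with batch_size=4) raises a ValueError at construction time, so misconfigured runs fail before any training starts.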
