Trainer

from modern_llm.training.trainer_base import Trainer
Causal LM trainer with gradient accumulation and automatic mixed precision (AMP) support. Handles the core training loop, evaluation, checkpointing, and logging.

Parameters

model
nn.Module
required
The PyTorch model to train. Must return a loss when labels are provided.
optimizer
Optimizer
required
PyTorch optimizer instance (e.g., AdamW) for parameter updates.
train_dataloader
Iterable
required
DataLoader yielding training batches with keys input_ids, attention_mask, and labels (see the batch sketch after this parameter list).
config
TrainingConfig
required
Training configuration containing hyperparameters, logging settings, and output directory.
eval_dataloader
Optional[Iterable]
default:"None"
Optional validation DataLoader. If provided, evaluation runs every config.eval_every steps.
lr_scheduler
Optional[_LRScheduler]
default:"None"
Optional learning rate scheduler that steps after each optimizer update.
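Each batch from train_dataloader is expected to be a dict of tensors with the three keys noted above. A minimal sketch of one such batch, assuming a GPT-2 vocab size of 50257, a micro-batch size of 4, and a sequence length of 128:

import torch

batch = {
    "input_ids": torch.randint(0, 50257, (4, 128)),          # (micro_batch_size, seq_len)
    "attention_mask": torch.ones(4, 128, dtype=torch.long),  # 1 = real token, 0 = padding
    "labels": torch.randint(0, 50257, (4, 128)),             # next-token targets
}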

Attributes

device
torch.device
Device (CPU or CUDA) where model and data reside.
logger
logging.Logger
Logger instance for training metrics and events.
use_amp
bool
Whether automatic mixed precision is enabled (fp16 or bf16).
scaler
Optional[GradScaler]
Gradient scaler for fp16 mixed precision. None for bf16 or fp32.
global_step
int
Total number of optimizer steps completed (after gradient accumulation).
micro_step
int
Total number of forward/backward passes (before gradient accumulation).
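The two counters are linked by the accumulation factor. A small illustration of the intended invariant, assuming both start at zero and every optimizer step follows exactly gradient_accumulation_steps forward/backward passes:

# micro_step counts forward/backward passes; global_step counts optimizer steps
assert trainer.global_step == trainer.micro_step // train_config.gradient_accumulation_steps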

Methods

train

trainer.train() -> None
Run the full training loop until max_steps is reached. Handles gradient accumulation, logging, evaluation, and checkpointing according to the training config. Preconditions:
  • Model, optimizer, and dataloaders must be initialized
Postconditions:
  • Checkpoints and logs emitted per configuration
  • Final checkpoint saved at {run_name}_final.pt

evaluate

trainer.evaluate() -> Dict[str, float]
Run evaluation on the validation set. Returns: Dictionary containing:
  • loss: Average validation loss
  • perplexity: Exponential of average loss
Example:
metrics = trainer.evaluate()
print(f"Validation loss: {metrics['loss']:.4f}")
print(f"Perplexity: {metrics['perplexity']:.2f}")

Usage

import torch
from torch.utils.data import DataLoader
from modern_llm.training.trainer_base import Trainer
from modern_llm.config import TrainingConfig
from modern_llm.models import ModernDecoderLM
from pathlib import Path

# Initialize model (model_config is a model configuration built elsewhere;
# it is distinct from the TrainingConfig created below)
model = ModernDecoderLM.from_config(model_config)

# Setup optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.01,
)

# Create training config
train_config = TrainingConfig(
    run_name="my-training",
    dataset_name="wikitext",
    tokenizer_name="gpt2",
    output_dir=Path("experiments/runs"),
    batch_size=32,
    micro_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=3e-4,
    max_steps=10000,
    warmup_steps=500,
    eval_every=1000,
    save_every=2000,
    log_every=100,
    mixed_precision="bf16",
)

# DataLoaders and scheduler (train_dataset / eval_dataset are built elsewhere)
train_loader = DataLoader(train_dataset, batch_size=train_config.micro_batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=train_config.micro_batch_size)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_config.max_steps)

# Create trainer
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    train_dataloader=train_loader,
    eval_dataloader=eval_loader,
    config=train_config,
    lr_scheduler=scheduler,
)

# Run training
trainer.train()

Features

Gradient Accumulation

Accumulates gradients over multiple micro-batches before performing an optimizer step. Effective batch size = micro_batch_size × gradient_accumulation_steps.

Mixed Precision Training

Supports fp16, bf16, and fp32 training modes:
  • fp16: Uses gradient scaling for numerical stability
  • bf16: No gradient scaling needed, better dynamic range
  • fp32: Full precision training
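A minimal sketch of one accumulated fp16 step, following the standard PyTorch AMP recipe rather than the Trainer's actual source (accum_steps, max_grad_norm, and the .loss attribute on the model output are assumptions; clipping is described below):

import torch

scaler = torch.cuda.amp.GradScaler()  # fp16 only; bf16 and fp32 skip the scaler

for i, batch in enumerate(train_loader):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(**batch)
        loss = outputs.loss / accum_steps  # scale so accumulated grads average out
    scaler.scale(loss).backward()
    if (i + 1) % accum_steps == 0:
        scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)      # skips the step if gradients overflowed
        scaler.update()
        optimizer.zero_grad(set_to_none=True)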
Model Compilation

If config.compile_model=True and PyTorch 2.0+ is available, applies torch.compile() for faster training.

Gradient Clipping

Clips gradients by norm if config.max_grad_norm > 0 to prevent exploding gradients.

Automatic Checkpointing

Saves model and optimizer state every save_every steps, plus a final checkpoint at completion.
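The exact checkpoint schema is not pinned down on this page; a plausible save/load sketch under that caveat (the dictionary key names are assumptions, not the Trainer's documented format):

import torch

ckpt_path = train_config.output_dir / f"{train_config.run_name}_final.pt"

# Saving
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "global_step": trainer.global_step,
    },
    ckpt_path,
)

# Restoring later
state = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state["model"])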
