Trainer
Parameters
The PyTorch model to train. Must return a loss when labels are provided.
PyTorch optimizer instance (e.g., AdamW) for parameter updates.
DataLoader yielding training batches with keys input_ids, attention_mask, and labels.
Training configuration containing hyperparameters, logging settings, and output directory.
Optional validation DataLoader. If provided, evaluation runs every config.eval_every steps.
Optional learning rate scheduler that steps after each optimizer update.
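For context, a minimal sketch of a dataset whose DataLoader yields batches in the expected shape. The toy data and labels-mirror-inputs convention here are illustrative assumptions, not part of the Trainer API:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyTextDataset(Dataset):
    """Illustrative dataset producing the batch keys the Trainer expects."""

    def __init__(self, num_examples=128, seq_len=16, vocab_size=1000):
        self.data = torch.randint(0, vocab_size, (num_examples, seq_len))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ids = self.data[idx]
        return {
            "input_ids": ids,
            "attention_mask": torch.ones_like(ids),  # no padding in this toy data
            "labels": ids.clone(),  # causal-LM convention: labels mirror inputs
        }

train_dataloader = DataLoader(ToyTextDataset(), batch_size=8, shuffle=True)
```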
Attributes
Device (CPU or CUDA) where model and data reside.
Logger instance for training metrics and events.
Whether automatic mixed precision is enabled (fp16 or bf16).
Gradient scaler for fp16 mixed precision. None for bf16 or fp32.
Total number of optimizer steps completed (after gradient accumulation).
Total number of forward/backward passes (before gradient accumulation).
Methods
train
Runs the training loop until max_steps is reached. Handles gradient accumulation, logging, evaluation, and checkpointing according to the training config.
Preconditions:
- Model, optimizer, and dataloaders must be initialized

Side effects:
- Checkpoints and logs emitted per configuration
- Final checkpoint saved at {run_name}_final.pt
evaluate
Runs evaluation over the validation DataLoader. Returns:
loss: Average validation loss
perplexity: Exponential of the average loss
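How the two values relate, as a hedged sketch (the loop body is an assumption about the implementation, built only on the fact that the model returns a loss when labels are provided):

```python
import math
import torch

@torch.no_grad()
def evaluate_sketch(model, eval_dataloader, device):
    model.eval()
    total_loss, num_batches = 0.0, 0
    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = model(**batch).loss  # assumes an output object with a .loss field
        total_loss += loss.item()
        num_batches += 1
    avg_loss = total_loss / num_batches
    return {"loss": avg_loss, "perplexity": math.exp(avg_loss)}
```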
Usage
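A minimal sketch of putting the pieces together. The keyword argument names are assumptions based on the Parameters section above, not confirmed signatures; MyModel and config are placeholders:

```python
from torch.optim import AdamW

model = MyModel()  # any PyTorch model that returns a loss when labels are provided
optimizer = AdamW(model.parameters(), lr=3e-4)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,  # optional
    config=config,  # training configuration (hyperparameters, logging, output dir)
)
trainer.train()
metrics = trainer.evaluate()
print(metrics["loss"], metrics["perplexity"])
```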
Features
Gradient Accumulation
Accumulates gradients over multiple micro-batches before performing an optimizer step. Effective batch size = micro_batch_size × gradient_accumulation_steps. The general pattern is sketched below.
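A generic sketch of the technique, not the Trainer's exact loop:

```python
accumulation_steps = 4  # with micro_batch_size=8, effective batch size is 32

optimizer.zero_grad()
for i, batch in enumerate(train_dataloader):
    loss = model(**batch).loss
    (loss / accumulation_steps).backward()  # scale so gradients average across micro-batches
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()  # one optimizer step per accumulation window
        optimizer.zero_grad()
```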
Mixed Precision Training
Supports fp16, bf16, and fp32 training modes:
fp16: Uses gradient scaling for numerical stability
bf16: No gradient scaling needed, better dynamic range
fp32: Full precision training
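In PyTorch these modes typically map onto torch.autocast plus a gradient scaler, roughly as follows (a generic sketch, not the Trainer's internals; fp32 would skip both the autocast context and the scaler):

```python
import torch

use_fp16 = True  # else bf16 in this sketch
dtype = torch.float16 if use_fp16 else torch.bfloat16
scaler = torch.cuda.amp.GradScaler() if use_fp16 else None

for batch in train_dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=dtype):
        loss = model(**batch).loss
    if scaler is not None:  # fp16: scale the loss to avoid gradient underflow
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:  # bf16: wider dynamic range makes scaling unnecessary
        loss.backward()
        optimizer.step()
```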
Model Compilation
If config.compile_model=True and PyTorch 2.0+ is available, applies torch.compile() for faster training.
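The standard guard for this looks like the following (the compile_model flag name comes from the text above; the availability check is an assumption):

```python
import torch

if getattr(config, "compile_model", False) and hasattr(torch, "compile"):
    model = torch.compile(model)  # torch.compile is available from PyTorch 2.0
```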
Gradient Clipping
Clips gradients by norm if config.max_grad_norm > 0 to prevent exploding gradients.
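A generic sketch of where clipping sits in the step, using torch.nn.utils.clip_grad_norm_; under fp16 the gradients must be unscaled before clipping:

```python
loss.backward()  # or scaler.scale(loss).backward() under fp16
if config.max_grad_norm > 0:
    if scaler is not None:
        scaler.unscale_(optimizer)  # fp16: clip true gradients, not scaled ones
    torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
if scaler is not None:
    scaler.step(optimizer)
    scaler.update()
else:
    optimizer.step()
```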
Automatic Checkpointing
Saves model and optimizer state every save_every steps, plus a final checkpoint at completion.
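A minimal sketch of what such a checkpoint could contain. The dictionary keys, the step-numbered filename, and the config.run_name / config.output_dir attribute names are assumptions; only the {run_name}_final.pt pattern and the saved state (model plus optimizer) come from the text above:

```python
import os
import torch

def save_checkpoint(model, optimizer, global_step, config, final=False):
    # Hypothetical naming scheme; only the final-checkpoint pattern is documented.
    name = f"{config.run_name}_final.pt" if final else f"{config.run_name}_step{global_step}.pt"
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "global_step": global_step,
        },
        os.path.join(config.output_dir, name),
    )
```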