TRL is designed with modularity in mind so that users can efficiently customize the training loop for their needs. The techniques below apply to most (if not all) trainers in TRL.
The examples on this page use [DPOTrainer], but the same customization patterns apply across all TRL trainers.
## Custom optimizers and schedulers
By default, TRL trainers create a `torch.optim.AdamW` optimizer. You can pass a custom optimizer directly:
```python
from datasets import load_dataset
from torch import optim
from transformers import AutoModelForCausalLM
from trl import DPOTrainer

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
optimizer = optim.SGD(model.parameters(), lr=1e-6)

trainer = DPOTrainer(
    model=model,
    train_dataset=dataset,
    optimizers=(optimizer, None),
)
trainer.train()
```
### Adding a learning rate scheduler
Pass both the optimizer and the scheduler as a tuple via `optimizers`:
```python
from torch import optim

optimizer = optim.AdamW(model.parameters(), lr=1e-6)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

trainer = DPOTrainer(..., optimizers=(optimizer, lr_scheduler))
```
## 8-bit reference models
TRL supports all keyword arguments accepted by `from_pretrained`, so you can load the reference model in 8-bit by passing a `BitsAndBytesConfig` through the `quantization_config` argument for more memory-efficient fine-tuning:
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)
ref_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    quantization_config=quantization_config,
)

trainer = DPOTrainer(..., ref_model=ref_model)
```
See the Transformers PEFT docs for more on 8-bit and 4-bit model loading.
## Custom callbacks
Callbacks let you execute code at specific points during training — useful for custom logging, monitoring, or early stopping.
```python
from transformers import TrainerCallback

class CustomLoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            print(f"Step {state.global_step}: {logs}")

trainer = DPOTrainer(..., callbacks=[CustomLoggingCallback()])
```
Callbacks inherit from `transformers.TrainerCallback`. You can override any lifecycle hook, such as `on_train_begin`, `on_epoch_end`, `on_evaluate`, and more.
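Hooks can also influence training through the `control` object they receive. As a sketch, the hypothetical callback below stops training once the logged loss falls below a threshold (the class name and threshold are illustrative, not part of TRL):

```python
from transformers import TrainerCallback, TrainerControl

class LossThresholdCallback(TrainerCallback):
    """Hypothetical early-stopping callback: halt once the logged loss is low enough."""

    def __init__(self, threshold: float = 0.1):
        self.threshold = threshold

    def on_log(self, args, state, control, logs=None, **kwargs):
        # `logs` is the dict the trainer is about to log; "loss" is present on training steps.
        if logs is not None and logs.get("loss", float("inf")) < self.threshold:
            control.should_training_stop = True
        return control
```

Setting `control.should_training_stop = True` asks the trainer to finish the current step and exit the training loop cleanly.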
## Custom evaluation metrics
Define a `compute_metrics` function and pass it to the trainer. The function receives an `EvalPrediction` object containing logits and labels:
```python
def compute_metrics(eval_preds):
    logits, labels = eval_preds
    # Add your metric computation here
    return {"custom_metric": 0.0}

training_args = DPOConfig(..., eval_strategy="steps", eval_steps=100)
trainer = DPOTrainer(
    ...,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)
```
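As a concrete sketch, here is an accuracy metric assuming `logits` has shape `(batch, num_classes)` and `labels` holds class indices; the exact contents of `eval_preds` depend on the trainer and model, so adapt the unpacking to your setup:

```python
import numpy as np

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)  # highest-scoring class per example
    accuracy = float((predictions == labels).mean())
    return {"accuracy": accuracy}
```

The returned dict is merged into the evaluation logs, so `"accuracy"` will appear alongside the built-in eval metrics.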
## Mixed precision training
Mixed precision can significantly speed up training and reduce memory usage. Set `bf16=True` or `fp16=True` in the training config:
```python
# bfloat16: recommended for Ampere GPUs (A100, RTX 30xx) and newer
training_args = DPOConfig(..., bf16=True)

# float16: use for older GPUs
training_args = DPOConfig(..., fp16=True)
```
Use bf16=True on Ampere (A100, RTX 30xx) or newer GPUs. Use fp16=True on older architectures.
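If you are unsure which mode your hardware supports, you can pick it at runtime; this is a sketch assuming a CUDA build of PyTorch (on a CPU-only machine both flags stay off):

```python
import torch

# Prefer bfloat16 when the GPU supports it (Ampere and newer); fall back to float16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
use_fp16 = torch.cuda.is_available() and not use_bf16

# training_args = DPOConfig(..., bf16=use_bf16, fp16=use_fp16)
```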
## Gradient accumulation
Gradient accumulation simulates larger batch sizes on limited GPU memory by accumulating gradients over multiple steps before updating weights:
```python
# Simulate a batch size of 32 with per_device_train_batch_size=4 and 8 accumulation steps
training_args = DPOConfig(
    ...,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
)
```
The effective batch size is:
```
effective_batch_size = per_device_train_batch_size × num_devices × gradient_accumulation_steps
```
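As a quick sanity check, plugging the config above into the formula on a single GPU gives:

```python
per_device_train_batch_size = 4
num_devices = 1  # single GPU
gradient_accumulation_steps = 8

effective_batch_size = per_device_train_batch_size * num_devices * gradient_accumulation_steps
print(effective_batch_size)  # → 32
```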