TRL is designed with modularity in mind so that users can efficiently customize the training loop for their needs. The techniques below apply to most (if not all) trainers in TRL.
The examples on this page use DPOTrainer, but the same customization patterns apply across all TRL trainers.

Custom optimizers and schedulers

By default, TRL trainers create a torch.optim.AdamW optimizer. You can pass a custom optimizer directly:
from datasets import load_dataset
from torch import optim
from transformers import AutoModelForCausalLM
from trl import DPOTrainer

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
optimizer = optim.SGD(model.parameters(), lr=1e-6)

trainer = DPOTrainer(
    model=model,
    train_dataset=dataset,
    optimizers=(optimizer, None),
)
trainer.train()

Adding a learning rate scheduler

Pass both optimizer and scheduler as a tuple via optimizers:
from torch import optim

optimizer = optim.AdamW(model.parameters(), lr=1e-6)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

trainer = DPOTrainer(..., optimizers=(optimizer, lr_scheduler))
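To see what this schedule actually does, here is a standalone sketch (pure PyTorch, no trainer involved) stepping the same StepLR configuration: the learning rate is multiplied by gamma=0.1 once every step_size=30 scheduler steps.

```python
import torch
from torch import optim

# Dummy parameter so the optimizer has something to manage.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.AdamW([param], lr=1e-6)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for _ in range(60):
    optimizer.step()
    lr_scheduler.step()

# The lr has been cut twice by now: 1e-6 -> 1e-7 -> 1e-8 (up to float rounding).
print(optimizer.param_groups[0]["lr"])
```

When the scheduler is passed to the trainer via optimizers, the trainer calls lr_scheduler.step() for you; you only step it manually in a standalone loop like this one.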

8-bit reference models

TRL supports all keyword arguments accepted by from_pretrained, so you can load the reference model in 8-bit by passing a BitsAndBytesConfig via quantization_config for more memory-efficient fine-tuning:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)
ref_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    quantization_config=quantization_config,
)

trainer = DPOTrainer(..., ref_model=ref_model)
See the Transformers PEFT docs for more on 8-bit and 4-bit model loading.

Custom callbacks

Callbacks let you execute code at specific points during training — useful for custom logging, monitoring, or early stopping.
from transformers import TrainerCallback


class CustomLoggingCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            print(f"Step {state.global_step}: {logs}")


trainer = DPOTrainer(..., callbacks=[CustomLoggingCallback()])
Callbacks inherit from transformers.TrainerCallback. You can override any lifecycle hook such as on_train_begin, on_epoch_end, on_evaluate, and more.
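As a fuller sketch of the early-stopping use case mentioned above, a callback can flip control.should_training_stop from the on_evaluate hook. The metric key eval_loss is the evaluation-loss key that transformers logs by default; the patience value here is an illustrative choice, not a TRL default.

```python
from transformers import TrainerCallback


class PatienceEarlyStoppingCallback(TrainerCallback):
    """Stop training when eval_loss has not improved for `patience` evaluations."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best_loss = float("inf")
        self.bad_evals = 0

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics is None or "eval_loss" not in metrics:
            return
        if metrics["eval_loss"] < self.best_loss:
            # New best: reset the counter.
            self.best_loss = metrics["eval_loss"]
            self.bad_evals = 0
        else:
            self.bad_evals += 1
            if self.bad_evals >= self.patience:
                control.should_training_stop = True
```

Pass it to the trainer as before: DPOTrainer(..., callbacks=[PatienceEarlyStoppingCallback(patience=3)]). Note that transformers also ships a built-in EarlyStoppingCallback you can use instead of rolling your own.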

Custom evaluation metrics

Define a compute_metrics function and pass it to the trainer. The function receives an EvalPrediction object containing the model predictions and labels:
def compute_metrics(eval_preds):
    logits, labels = eval_preds
    # Add your metric computation here
    return {"custom_metric": 0.0}


training_args = DPOConfig(..., eval_strategy="steps", eval_steps=100)

trainer = DPOTrainer(
    ...,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)
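For example, the placeholder above could be filled in with a simple accuracy computation. This is a minimal sketch assuming classification-style outputs (logits of shape (batch, num_classes) and integer labels); the exact contents of eval_preds depend on the trainer and model.

```python
import numpy as np


def compute_metrics(eval_preds):
    # Assumption for illustration: logits (batch, num_classes), labels (batch,).
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    accuracy = float((predictions == labels).mean())
    return {"accuracy": accuracy}
```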

Mixed precision training

Mixed precision can significantly speed up training and reduce memory usage. Set bf16=True or fp16=True in the training config:
# bfloat16 — recommended for Ampere GPUs (A100, RTX 30xx) and newer
training_args = DPOConfig(..., bf16=True)

# float16 — use for older GPUs
training_args = DPOConfig(..., fp16=True)
Use bf16=True on Ampere (A100, RTX 30xx) or newer GPUs; bfloat16 keeps float32's dynamic range and is generally more numerically stable than float16. Fall back to fp16=True on older architectures that lack bfloat16 support.
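You can also pick the flag at runtime instead of hardcoding it. A small sketch using PyTorch's capability check (only meaningful when a CUDA device is present):

```python
import torch

# Prefer bf16 where the hardware supports it, otherwise fall back to fp16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
use_fp16 = torch.cuda.is_available() and not use_bf16

# Then, for example: training_args = DPOConfig(..., bf16=use_bf16, fp16=use_fp16)
```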

Gradient accumulation

Gradient accumulation simulates larger batch sizes on limited GPU memory by accumulating gradients over multiple steps before updating weights:
# Simulate a batch size of 32 with per_device_train_batch_size=4 and 8 accumulation steps
training_args = DPOConfig(
    ...,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
)
The effective batch size is:
effective_batch_size = per_device_train_batch_size × num_devices × gradient_accumulation_steps
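Plugging in the numbers from the config above (and assuming single-GPU training, so num_devices is 1):

```python
per_device_train_batch_size = 4
num_devices = 1  # assumption for illustration: a single GPU
gradient_accumulation_steps = 8

effective_batch_size = (
    per_device_train_batch_size * num_devices * gradient_accumulation_steps
)
print(effective_batch_size)  # 32
```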
