TRL provides a set of TrainerCallback subclasses that extend Hugging Face Trainer with reinforcement-learning-specific features such as exponential moving average weight tracking, reference model synchronization, completion logging, and third-party observability integrations. Import all callbacks from the top-level trl package:
from trl import (
    BEMACallback,
    LogCompletionsCallback,
    RichProgressCallback,
    SyncRefModelCallback,
    WeaveCallback,
)

BEMACallback

BEMACallback implements Bias-Corrected Exponential Moving Average (BEMA), introduced in Block & Zhang (2025). It maintains a running shadow model whose weights track the training model via a bias-corrected EMA scheme:

$\theta_t' = \alpha_t \cdot (\theta_t - \theta_0) + \text{EMA}_t$

where $\alpha_t = (\rho + \gamma \cdot t)^{-\eta}$ decays with the step count. The EMA itself is updated as:

$\text{EMA}_t = (1 - \beta_t) \cdot \text{EMA}_{t-1} + \beta_t \cdot \theta_t, \quad \beta_t = (\rho + \gamma \cdot t)^{-\kappa}$

At the end of training the shadow model is saved to {output_dir}/bema/.
The BEMA buffers live on a separate device (default "cpu") to avoid out-of-memory errors on the training accelerator.
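To build intuition for how the two schedules behave, here is a small sketch (not the TRL implementation) that evaluates $\alpha_t$ and $\beta_t$ using the callback's default hyperparameters, where lag is $\rho$, multiplier is $\gamma$, ema_power is $\kappa$, and bias_power is $\eta$:

```python
# Illustrative sketch of the BEMA decay schedules with BEMACallback defaults:
# lag (rho) = 10, multiplier (gamma) = 1.0,
# ema_power (kappa) = 0.5, bias_power (eta) = 0.2.

def beta(t: int, rho: int = 10, gamma: float = 1.0, kappa: float = 0.5) -> float:
    """EMA decay factor: beta_t = (rho + gamma * t) ** -kappa."""
    return (rho + gamma * t) ** -kappa

def alpha(t: int, rho: int = 10, gamma: float = 1.0, eta: float = 0.2) -> float:
    """Bias-correction scale: alpha_t = (rho + gamma * t) ** -eta."""
    return (rho + gamma * t) ** -eta

# Both factors shrink as training progresses, so late checkpoints
# perturb the shadow model less than early ones.
for t in (0, 400, 2000):
    print(f"t={t:5d}  beta_t={beta(t):.4f}  alpha_t={alpha(t):.4f}")
```

Setting ema_power or bias_power to 0.0 pins the corresponding factor at 1.0, which is how the "disable" switches described below work.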

Signature

class BEMACallback(TrainerCallback):
    def __init__(
        self,
        update_freq: int = 400,
        ema_power: float = 0.5,
        bias_power: float = 0.2,
        lag: int = 10,
        update_after: int = 0,
        multiplier: float = 1.0,
        min_ema_multiplier: float = 0.0,
        device: str = "cpu",
    )

Parameters

update_freq
int
default:"400"
Number of steps between updates of the BEMA shadow model. Denoted $\phi$ in the paper.
ema_power
float
default:"0.5"
Exponent $\kappa$ controlling the EMA decay factor $\beta_t$. Set to 0.0 to disable EMA.
bias_power
float
default:"0.2"
Exponent $\eta$ controlling the BEMA scaling factor $\alpha_t$. Set to 0.0 to disable bias correction.
lag
int
default:"10"
Initial offset $\rho$ in the weight decay schedule. Controls smoothness in early training by acting as a virtual starting age.
update_after
int
default:"0"
Burn-in steps $\tau$ before BEMA updates begin. The snapshot $\theta_0$ is taken at this step.
multiplier
float
default:"1.0"
Step multiplier $\gamma$ applied to the step count inside the decay schedule.
min_ema_multiplier
float
default:"0.0"
Floor value for the EMA decay factor $\beta_t$.
device
str
default:"'cpu'"
Device for BEMA buffers. Should differ from the training device to avoid OOM errors.

Example

from trl import BEMACallback
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(output_dir="./output", max_steps=2000)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[BEMACallback(update_freq=400, device="cpu")],
)
trainer.train()
# Shadow model saved to ./output/bema/

SyncRefModelCallback

SyncRefModelCallback periodically synchronizes a reference model toward the current training model using an exponential moving average blend controlled by ref_model_mixup_alpha. It is used by trainers such as DPOTrainer when a soft-update reference policy is desired. The sync is triggered at every step where global_step % args.ref_model_sync_steps == 0.
DeepSpeed ZeRO Stage 3 is handled automatically: parameters are gathered across ranks before the blend is applied.
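Conceptually, each sync applies a convex blend of the two parameter sets. The following sketch (illustrative only; it ignores the DeepSpeed gathering and model unwrapping that the real callback performs) shows the update with alpha standing in for ref_model_mixup_alpha:

```python
import torch

# Illustrative soft-update blend applied at each sync step:
#   ref <- alpha * policy + (1 - alpha) * ref
# This is a sketch, not the TRL source.
@torch.no_grad()
def soft_update(ref_model: torch.nn.Module, model: torch.nn.Module, alpha: float) -> None:
    for ref_param, param in zip(ref_model.parameters(), model.parameters()):
        # In-place blend: scale the reference weights, then add the
        # alpha-weighted policy weights.
        ref_param.data.mul_(1.0 - alpha).add_(param.data, alpha=alpha)
```

With alpha close to 1 the reference model tracks the policy closely; with alpha close to 0 it changes only slowly between syncs.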

Signature

class SyncRefModelCallback(TrainerCallback):
    def __init__(
        self,
        ref_model: PreTrainedModel | torch.nn.Module,
        accelerator: Accelerator | None,
    )

Parameters

ref_model
PreTrainedModel | torch.nn.Module
The reference model to keep synchronized with the training model.
accelerator
Accelerator | None
Accelerate Accelerator instance used to unwrap the model before syncing. Pass None if not using Accelerate.
The sync frequency (ref_model_sync_steps) and blend coefficient (ref_model_mixup_alpha) are read from TrainingArguments at runtime, not from the callback constructor.

Example

from trl import SyncRefModelCallback
from accelerate import Accelerator

accelerator = Accelerator()
callback = SyncRefModelCallback(ref_model=ref_model, accelerator=accelerator)
trainer.add_callback(callback)
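When using a TRL trainer that supports soft reference updates, these runtime values are typically set on the trainer's config rather than on the callback. A sketch using DPOConfig (field names as documented by TRL; check your version):

```python
from trl import DPOConfig

# Sketch: configure the sync schedule and blend on the trainer config.
training_args = DPOConfig(
    output_dir="./output",
    sync_ref_model=True,        # have DPOTrainer attach SyncRefModelCallback
    ref_model_sync_steps=64,    # sync every 64 optimizer steps
    ref_model_mixup_alpha=0.6,  # blend coefficient for the soft update
)
```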

LogCompletionsCallback

LogCompletionsCallback generates model completions for prompts from the evaluation dataset at regular intervals and logs them as a table to Weights & Biases and/or Comet ML. This makes it easy to track qualitative output quality throughout training.
The trainer must have an evaluation dataset with a "prompt" column. A ValueError is raised at construction time if the dataset is absent.

Signature

class LogCompletionsCallback(TrainerCallback):
    def __init__(
        self,
        trainer: Trainer,
        generation_config: GenerationConfig | None = None,
        num_prompts: int | None = None,
        freq: int | None = None,
    )

Parameters

trainer
Trainer
The trainer instance to attach the callback to. Used to access the model, tokenizer, accelerator, and evaluation dataset.
generation_config
GenerationConfig
Generation configuration used when producing completions. If not provided the model’s default config is used.
num_prompts
int
Number of prompts sampled from the evaluation dataset. Defaults to the full evaluation dataset.
freq
int
Logging frequency in steps. Defaults to trainer.args.eval_steps.

Example

from trl import DPOTrainer, LogCompletionsCallback

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # must have a "prompt" column
)
completions_callback = LogCompletionsCallback(trainer=trainer, num_prompts=32)
trainer.add_callback(completions_callback)
trainer.train()

RichProgressCallback

RichProgressCallback replaces the default tqdm-based progress display with a Rich layout that shows training and evaluation progress bars alongside a live metrics table grouped by prefix.
This callback requires the rich package: pip install rich.

Signature

class RichProgressCallback(TrainerCallback):
    def __init__(self)
No constructor arguments are required.

Example

from trl import RichProgressCallback
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[RichProgressCallback()],
)
trainer.train()

WeaveCallback

WeaveCallback logs completions and optional scorer evaluations to Weights & Biases Weave during evaluation steps. It supports two modes:
  • Tracing mode (scorers=None): logs predictions for data exploration.
  • Evaluation mode (scorers provided): logs predictions with per-scorer scores and summary statistics.
Both modes use Weave’s EvaluationLogger for structured logging.
The trainer must have an evaluation dataset with a "prompt" column. A ValueError is raised at construction time if absent.

Signature

class WeaveCallback(TrainerCallback):
    def __init__(
        self,
        trainer: Trainer,
        project_name: str | None = None,
        scorers: dict[str, Callable] | None = None,
        generation_config: GenerationConfig | None = None,
        num_prompts: int | None = None,
        dataset_name: str = "eval_dataset",
        model_name: str | None = None,
    )

Parameters

trainer
Trainer
Trainer instance to attach the callback to.
project_name
str
Weave project name for logging. If not provided, the callback tries the existing Weave client, then the active wandb run. Raises a ValueError if none is available.
scorers
dict[str, Callable]
Mapping of scorer names to scorer functions with signature scorer(prompt: str, completion: str) -> float | int. When provided, enables evaluation mode.
generation_config
GenerationConfig
Generation configuration for producing completions.
num_prompts
int
Number of evaluation prompts to use. Defaults to the full evaluation dataset.
dataset_name
str
default:"'eval_dataset'"
Name label for the dataset metadata in Weave.
model_name
str
Name label for the model metadata in Weave. Extracted automatically from model.config._name_or_path if not provided.

Example

from trl import DPOTrainer, WeaveCallback

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Tracing mode — log predictions only
weave_callback = WeaveCallback(trainer=trainer, project_name="my-llm-training")
trainer.add_callback(weave_callback)
trainer.train()
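To enable evaluation mode instead, pass a scorers mapping. Each scorer takes the prompt and the generated completion and returns a number. The brevity function below is a hypothetical toy scorer, not part of TRL:

```python
# Hypothetical toy scorer: full score for completions under 50 words,
# scaling down linearly beyond that.
def brevity(prompt: str, completion: str) -> float:
    n_words = len(completion.split())
    return min(1.0, 50 / n_words) if n_words else 0.0

scorers = {"brevity": brevity}

# Evaluation mode - pass the mapping to the callback, e.g.:
# weave_callback = WeaveCallback(
#     trainer=trainer, project_name="my-llm-training", scorers=scorers
# )
# trainer.add_callback(weave_callback)
print(brevity("Summarize the paper.", "A short two-sentence summary."))
```

In evaluation mode Weave logs the per-scorer score for every prompt plus summary statistics across the evaluation set.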
