TrainerCallback subclasses that extend Hugging Face Trainer with reinforcement-learning-specific features such as exponential moving average weight tracking, reference model synchronization, completion logging, and third-party observability integrations.
All callbacks can be imported from the top-level trl package, e.g. from trl import BEMACallback.
BEMACallback
BEMACallback implements Bias-Corrected Exponential Moving Average (BEMA), introduced in Block & Zhang (2025). It maintains a running shadow model whose weights track the training model via a bias-corrected EMA scheme:
$$\theta^{\text{BEMA}}_t = \alpha_t\left(\theta_t - \theta_{t_0}\right) + \theta^{\text{EMA}}_t$$

where $\theta_{t_0}$ is the snapshot taken at the end of burn-in and the scaling factor $\alpha_t$ decays with the step count. The EMA itself is updated as:

$$\theta^{\text{EMA}}_t = (1 - \beta_t)\,\theta^{\text{EMA}}_{t-1} + \beta_t\,\theta_t$$

where the decay factor $\beta_t$ likewise decreases as training progresses.
At the end of training the shadow model is saved to {output_dir}/bema/.
The BEMA buffers live on a separate device (default "cpu") to avoid out-of-memory errors on the training accelerator.

Signature
Parameters
Update the BEMA shadow model every this many steps. Denoted in the paper.
Exponent controlling the EMA decay factor. Set to 0.0 to disable EMA.
Exponent controlling the BEMA scaling factor. Set to 0.0 to disable bias correction.
Initial offset in the weight decay schedule. Controls smoothness in early training by acting as a virtual starting age.
Burn-in steps before BEMA updates begin. The snapshot is taken at this step.
Step multiplier applied to the step count inside the decay schedule.
Floor value for the EMA decay factor.
Device for BEMA buffers. Should differ from the training device to avoid OOM errors.
Example
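The original example is not reproduced here; as a stand-in, the update rule itself can be illustrated with a pure-Python sketch on scalar weights. The power-law schedules and the default values below are assumptions chosen to match the parameter descriptions above, not the exact trl implementation:

```python
def bema_step(theta_t, theta_snapshot, ema_prev, t,
              ema_power=0.5, bias_power=0.5, lag=10.0, multiplier=1.0):
    """One BEMA update on a scalar weight.

    The power-law schedules below are assumptions (decay factors that
    shrink with the step count); check the trl source for the exact forms.
    """
    beta = (lag + multiplier * t) ** (-ema_power)    # EMA decay factor
    alpha = (lag + multiplier * t) ** (-bias_power)  # BEMA scaling factor
    ema = (1.0 - beta) * ema_prev + beta * theta_t   # plain EMA update
    bema = alpha * (theta_t - theta_snapshot) + ema  # bias-corrected estimate
    return ema, bema

# Toy run: a weight converging toward 1.0 from the burn-in snapshot 0.0.
theta_snapshot, ema = 0.0, 0.0
for t in range(1, 6):
    theta_t = 1.0 - 0.5 ** t
    ema, bema = bema_step(theta_t, theta_snapshot, ema, t)
```

Because the current weight stays above the snapshot here, the bias-corrected estimate always leads the plain EMA toward the converged value.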
SyncRefModelCallback
SyncRefModelCallback periodically synchronizes a reference model toward the current training model using an exponential moving average blend controlled by ref_model_mixup_alpha. It is used by trainers such as DPOTrainer when a soft-update reference policy is desired.
The sync is triggered at every step where global_step % args.ref_model_sync_steps == 0.
DeepSpeed ZeRO Stage 3 is handled automatically: parameters are gathered across ranks before the blend is applied.
Signature
Parameters
The reference model to keep synchronized with the training model.
Accelerate Accelerator instance used to unwrap the model before syncing. Pass None if not using Accelerate.

The sync frequency (ref_model_sync_steps) and blend coefficient (ref_model_mixup_alpha) are read from TrainingArguments at runtime, not from the callback constructor.

Example
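As a minimal sketch of the mechanism on plain Python floats: the blend direction shown (ref ← α·ref + (1 − α)·policy) is an assumption, so check the trl source for the exact convention:

```python
def soft_update(ref_weights, policy_weights, mixup_alpha):
    # Blend each reference weight toward the current policy weight.
    # Assumed convention: ref <- alpha * ref + (1 - alpha) * policy.
    return [mixup_alpha * r + (1.0 - mixup_alpha) * p
            for r, p in zip(ref_weights, policy_weights)]

def maybe_sync(global_step, sync_steps, ref_weights, policy_weights, mixup_alpha):
    # Mirrors the trigger described above: sync when global_step divides evenly.
    if global_step % sync_steps == 0:
        return soft_update(ref_weights, policy_weights, mixup_alpha)
    return ref_weights
```

With mixup_alpha close to 1.0 the reference model moves slowly, giving the soft-update behavior the DPO trainers rely on.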
LogCompletionsCallback
LogCompletionsCallback generates model completions for prompts from the evaluation dataset at regular intervals and logs them as a table to Weights & Biases and/or Comet ML. This makes it easy to track qualitative output quality throughout training.
Signature
Parameters
The trainer instance to attach the callback to. Used to access the model, tokenizer, accelerator, and evaluation dataset.
Generation configuration used when producing completions. If not provided, the model's default generation config is used.
Number of prompts sampled from the evaluation dataset. Defaults to the full evaluation dataset.
Logging frequency in steps. Defaults to trainer.args.eval_steps.

Example
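The scheduling and table layout can be sketched in pure Python (a simplification of the behavior described above; the helper names are hypothetical):

```python
def should_log(global_step, freq, eval_steps):
    # freq=None falls back to trainer.args.eval_steps, as documented above.
    step = freq if freq is not None else eval_steps
    return step is not None and global_step > 0 and global_step % step == 0

def completions_table(prompts, completions, global_step):
    # One row per prompt/completion pair, as logged to W&B / Comet ML.
    return [{"step": global_step, "prompt": p, "completion": c}
            for p, c in zip(prompts, completions)]
```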
RichProgressCallback
RichProgressCallback replaces the default tqdm-based progress display with a Rich layout that shows training and evaluation progress bars alongside a live metrics table grouped by prefix.
This callback requires the rich package: pip install rich.

Signature
Example
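The prefix grouping behind the live metrics table can be sketched as follows (a hypothetical helper; the real callback renders the groups with Rich widgets):

```python
def group_by_prefix(logs):
    # Split flat metric names like "train/loss" into {"train": {"loss": ...}}.
    grouped = {}
    for name, value in logs.items():
        prefix, sep, key = name.partition("/")
        if not sep:
            # Assumed bucket for metrics that carry no "/" prefix.
            prefix, key = "misc", name
        grouped.setdefault(prefix, {})[key] = value
    return grouped
```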
WeaveCallback
WeaveCallback logs completions and optional scorer evaluations to Weights & Biases Weave during evaluation steps. It supports two modes:
- Tracing mode (scorers=None): logs predictions for data exploration.
- Evaluation mode (scorers provided): logs predictions with per-scorer scores and summary statistics.

Both modes use Weave's EvaluationLogger for structured logging.
Signature
Parameters
Trainer instance to attach the callback to.
Weave project name for logging. If not provided, the callback tries the existing Weave client, then the active wandb run. Raises a ValueError if none is available.
Mapping of scorer names to scorer functions with signature scorer(prompt: str, completion: str) -> float | int. When provided, enables evaluation mode.
Generation configuration for producing completions.
Number of evaluation prompts to use. Defaults to the full evaluation dataset.
Name label for the dataset metadata in Weave.
Name label for the model metadata in Weave. Extracted automatically from
model.config._name_or_path if not provided.
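For illustration, here are two toy scorers matching the documented signature; the scorer bodies are hypothetical, and any callables with this signature work:

```python
def length_score(prompt: str, completion: str) -> int:
    # Number of words in the completion.
    return len(completion.split())

def echoes_prompt(prompt: str, completion: str) -> float:
    # 1.0 if the completion contains the prompt's first word.
    words = prompt.split()
    return float(bool(words) and words[0].lower() in completion.lower())

scorers = {"length": length_score, "echoes_prompt": echoes_prompt}
```

Passing a mapping like this as scorers enables evaluation mode; leaving scorers as None keeps the callback in tracing mode.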