The trl.scripts module provides the building blocks for writing training scripts: a YAML-aware argument parser (TrlParser), common argument dataclasses (ScriptArguments, ModelConfig), a dataset mixture loader (get_dataset / DatasetMixtureConfig), and a logging initializer (init_zero_verbose).
from trl import (
    TrlParser,
    ScriptArguments,
    DatasetMixtureConfig,
    get_dataset,
    init_zero_verbose,
)
from trl.trainer import ModelConfig

TrlParser

TrlParser extends transformers.HfArgumentParser with support for YAML configuration files and environment variable injection. Pass --config path/to/config.yaml on the command line to load defaults from a file; command-line arguments always override config file values. The env key in the YAML file can set environment variables before the rest of the config is applied.
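The precedence rule can be sketched with stdlib argparse alone. This is an illustration of the behavior, not TrlParser's actual implementation: config-file values become parser defaults, so explicit command-line arguments still win.

```python
import argparse

# Minimal sketch of the precedence rule: values loaded from a config file
# are applied as parser *defaults*, so anything passed on the command line
# still overrides them.
parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=1e-5)

config_values = {"learning_rate": 3e-4}  # pretend this came from the YAML file
parser.set_defaults(**config_values)

args = parser.parse_args([])                           # no CLI override
print(args.learning_rate)                              # 0.0003 (from config)

args = parser.parse_args(["--learning_rate", "1e-6"])  # CLI override
print(args.learning_rate)                              # 1e-06 (CLI wins)
```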

Signature

class TrlParser(HfArgumentParser):
    def __init__(
        self,
        dataclass_types: DataClassType | Iterable[DataClassType] | None = None,
        **kwargs,
    )

Parameters

dataclass_types
DataClassType | Iterable[DataClassType]
One or more dataclass types to parse arguments into. None of the dataclasses may have a field named "config" (reserved for the config file path).

Methods

parse_args_and_config

def parse_args_and_config(
    self,
    args: Iterable[str] | None = None,
    return_remaining_strings: bool = False,
    fail_with_unknown_args: bool = True,
    separate_remaining_strings: bool = False,
) -> tuple[DataClass, ...]

Parses command-line arguments and an optional YAML config file. The config file (specified via --config) is loaded with yaml.safe_load. Its env section (if present) sets environment variables; all other keys are used as argument defaults. Raises ValueError for unknown config keys when fail_with_unknown_args=True.

set_defaults_with_config

def set_defaults_with_config(self, **kwargs) -> list[str]

Overrides argument defaults with values from keyword arguments (typically loaded from a YAML config) and marks overridden arguments as no longer required. Returns a list of string tokens for keys not recognized by the parser.
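The default-overriding behavior can be approximated with stdlib argparse. The helper below is a hypothetical sketch, not TRL's implementation: it reaches into the private parser._actions list and returns unmatched keys rather than TRL's exact token format.

```python
import argparse

def set_defaults_with_config(parser, **kwargs):
    """Sketch: apply config values as defaults for matching arguments,
    mark them as no longer required, and collect unrecognized keys."""
    unknown = []
    for key, value in kwargs.items():
        matched = False
        for action in parser._actions:  # private API; illustration only
            if action.dest == key:
                action.default = value
                action.required = False
                matched = True
        if not matched:
            unknown.append(key)
    return unknown

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", required=True)

remaining = set_defaults_with_config(parser, model_name="my-model", arg1=23)
args = parser.parse_args([])   # model_name is no longer required
print(args.model_name)         # my-model
print(remaining)               # ['arg1']
```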

Example

env:
  TOKENIZERS_PARALLELISM: "false"
arg1: 23
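What the parser does with such a file can be sketched as follows, using a plain dict in place of the yaml.safe_load output: the env section is applied to os.environ first, and every remaining key becomes an argument default.

```python
import os

# Stand-in for yaml.safe_load(open("config.yaml")) on the file above.
config = {"env": {"TOKENIZERS_PARALLELISM": "false"}, "arg1": 23}

# The env section is applied to the process environment first.
env_vars = config.pop("env", {})
for name, value in env_vars.items():
    os.environ[name] = str(value)  # environment variables must be strings

# Everything else is used as argument defaults.
defaults = config
print(os.environ["TOKENIZERS_PARALLELISM"])  # false
print(defaults)                              # {'arg1': 23}
```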

ScriptArguments

A dataclass holding dataset-related arguments common to all TRL training scripts. Designed to be used with TrlParser.

Signature

@dataclass
class ScriptArguments:
    dataset_name: str | None = None
    dataset_config: str | None = None
    dataset_train_split: str = "train"
    dataset_test_split: str = "test"
    dataset_streaming: bool = False
    ignore_bias_buffers: bool = False

Fields

dataset_name
str
Path or name of the dataset to load via datasets.load_dataset. Ignored when DatasetMixtureConfig.datasets is provided.
dataset_config
str
Dataset configuration name, corresponding to the name argument of datasets.load_dataset. Ignored when a mixture config is used.
dataset_train_split
str
default:"'train'"
Dataset split to use for training.
dataset_test_split
str
default:"'test'"
Dataset split to use for evaluation.
dataset_streaming
bool
default:"False"
When True, loads the dataset in streaming mode.
ignore_bias_buffers
bool
default:"False"
Debug flag for distributed training. Works around DDP errors caused by LM bias/mask buffers (invalid scalar type, in-place operation).

Example

from trl import TrlParser, ScriptArguments

parser = TrlParser(dataclass_types=[ScriptArguments])
(args,) = parser.parse_args_and_config()
print(args.dataset_name)

ModelConfig

A dataclass holding model loading and PEFT configuration, designed for use with TrlParser.

Signature

@dataclass
class ModelConfig:
    model_name_or_path: str | None = None
    model_revision: str = "main"
    dtype: str | None = "float32"
    trust_remote_code: bool = False
    attn_implementation: str | None = None
    use_peft: bool = False
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: list[str] | None = None
    lora_target_parameters: list[str] | None = None
    lora_modules_to_save: list[str] | None = None
    lora_task_type: str = "CAUSAL_LM"
    use_rslora: bool = False
    use_dora: bool = False
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    bnb_4bit_quant_type: str = "nf4"
    use_bnb_nested_quant: bool = False
    bnb_4bit_quant_storage: str | None = None

Key fields

model_name_or_path
str
Hugging Face Hub identifier or local path of the model checkpoint.
model_revision
str
default:"'main'"
Branch name, tag, or commit hash to load.
dtype
str
default:"'float32'"
Load dtype override. One of "auto", "bfloat16", "float16", "float32".
trust_remote_code
bool
default:"False"
Allow execution of custom model code from the Hub. Only enable for repositories you trust.
attn_implementation
str
Attention kernel to use (e.g., "flash_attention_2").
use_peft
bool
default:"False"
Enable PEFT/LoRA fine-tuning.
lora_r
int
default:"16"
LoRA rank.
lora_alpha
int
default:"32"
LoRA scaling factor.
lora_dropout
float
default:"0.05"
LoRA dropout probability.
lora_target_modules
list[str]
Module names to apply LoRA to.
lora_task_type
str
default:"'CAUSAL_LM'"
PEFT task type. Use "SEQ_CLS" for reward modeling.
use_rslora
bool
default:"False"
Use Rank-Stabilized LoRA (scales adapter by lora_alpha/√r instead of lora_alpha/r).
use_dora
bool
default:"False"
Enable Weight-Decomposed Low-Rank Adaptation (DoRA).
load_in_8bit
bool
default:"False"
Load in 8-bit precision (requires LoRA).
load_in_4bit
bool
default:"False"
Load in 4-bit precision (requires LoRA).
bnb_4bit_quant_type
str
default:"'nf4'"
4-bit quantization type: "fp4" or "nf4".
use_bnb_nested_quant
bool
default:"False"
Enable nested quantization (double quantization).
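The difference between the two adapter scaling rules selected by use_rslora is easy to see numerically, using the default lora_r and lora_alpha values above:

```python
import math

lora_r, lora_alpha = 16, 32  # ModelConfig defaults

standard_scale = lora_alpha / lora_r           # classic LoRA: alpha / r
rslora_scale = lora_alpha / math.sqrt(lora_r)  # rsLoRA: alpha / sqrt(r)

print(standard_scale)  # 2.0
print(rslora_scale)    # 8.0
```

At higher ranks the rsLoRA scaling shrinks more slowly, which is the motivation for the rank-stabilized variant.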

Example

from transformers import AutoModelForCausalLM

from trl import TrlParser, ScriptArguments
from trl.trainer import ModelConfig

parser = TrlParser(dataclass_types=[ScriptArguments, ModelConfig])
script_args, model_config = parser.parse_args_and_config()

model = AutoModelForCausalLM.from_pretrained(
    model_config.model_name_or_path,
    torch_dtype=model_config.dtype,
    trust_remote_code=model_config.trust_remote_code,
)

DatasetMixtureConfig

Configuration dataclass for loading and combining multiple datasets into a single training mixture. Each dataset in the mixture is described by a DatasetConfig entry.

Signature

@dataclass
class DatasetMixtureConfig:
    datasets: list[DatasetConfig] = field(default_factory=list)
    streaming: bool = False
    test_split_size: float | None = None

Fields

datasets
list[DatasetConfig]
List of individual dataset configurations. Each entry specifies a path, optional name, data_dir, data_files, split, and columns.
streaming
bool
default:"False"
Load all datasets in streaming mode.
test_split_size
float
If provided, the combined dataset is split into train and test subsets using this fraction as the test size.

YAML usage

datasets:
  - path: trl-lib/tldr
    split: train
  - path: trl-lib/ultrafeedback_binarized
    split: train
streaming: false
test_split_size: 0.05
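Conceptually, the mixture above is concatenated and then split. The sketch below imitates that with plain Python lists; the real loader operates on datasets.Dataset objects, and its train/test split also shuffles rather than taking the tail.

```python
# Plain-list stand-ins for the two datasets in the YAML mixture above.
tldr_rows = [{"prompt": f"t{i}"} for i in range(80)]
ultrafeedback_rows = [{"prompt": f"u{i}"} for i in range(20)]

# Step 1: concatenate the mixture into one dataset.
combined = tldr_rows + ultrafeedback_rows

# Step 2: carve off a test fraction (deterministic tail split here;
# the real implementation shuffles before splitting).
test_split_size = 0.05
n_test = int(len(combined) * test_split_size)
dataset = {
    "train": combined[: len(combined) - n_test],
    "test": combined[len(combined) - n_test :],
}
print(len(dataset["train"]), len(dataset["test"]))  # 95 5
```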

get_dataset

Loads and concatenates a mixture of datasets described by a DatasetMixtureConfig. Returns a DatasetDict with a "train" key (and optionally a "test" key when test_split_size is set).

Signature

def get_dataset(mixture_config: DatasetMixtureConfig) -> DatasetDict

Parameters

mixture_config
DatasetMixtureConfig
Configuration specifying datasets, streaming, and optional test split.

Returns

datasets.DatasetDict — Combined dataset. Always contains a "train" split; also contains a "test" split if mixture_config.test_split_size is not None.

Example

from trl import DatasetMixtureConfig, get_dataset
from trl.scripts.utils import DatasetConfig

mixture_config = DatasetMixtureConfig(
    datasets=[DatasetConfig(path="trl-lib/tldr")]
)
dataset = get_dataset(mixture_config)
print(dataset)
# DatasetDict({
#     train: Dataset({features: ['prompt', 'completion'], num_rows: 116722})
# })

init_zero_verbose

Configures Python’s logging and warnings for minimal, clean output — suitable for the top of CLI training scripts. Uses RichHandler when the rich package is available, falling back to a standard StreamHandler.

Signature

def init_zero_verbose() -> None
Sets the root log level to ERROR and redirects warnings.showwarning to the logging system.

Example

from trl import init_zero_verbose

init_zero_verbose()  # call early, before importing libraries that log at import time

from transformers import AutoModelForCausalLM
# ...
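Its effect can be approximated with the stdlib alone. This is a sketch of the described behavior, not the actual implementation, which also prefers RichHandler when rich is installed:

```python
import logging

# Approximation of init_zero_verbose: silence everything below ERROR and
# route warnings.warn calls through the logging system.
logging.basicConfig(level=logging.ERROR)
logging.getLogger().setLevel(logging.ERROR)
logging.captureWarnings(True)  # redirects warnings.showwarning to logging

print(logging.getLogger().level == logging.ERROR)  # True
```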
