
Overview

GRPO (Group Relative Policy Optimization) is an online RL algorithm introduced in the paper DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models. It is a variant of PPO that reduces memory usage by replacing the value model with a group-relative advantage estimate. At each step, GRPO generates a group of completions per prompt, computes rewards for each completion, normalizes the rewards within the group to obtain advantages, and updates the policy to increase the probability of high-advantage completions. This approach has become the standard method for training reasoning models such as DeepSeek-R1.

Quick start

# train_grpo.py
from datasets import load_dataset
from trl import GRPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
)
trainer.train()
Run the script with accelerate:

accelerate launch train_grpo.py

How GRPO works

1. Generate completions

At each training step, sample a batch of prompts and generate num_generations (G) completions per prompt.
2. Compute advantages

For each completion, compute a scalar reward. Normalize within the group:
  • Group normalization (default): subtract the group mean and divide by the group standard deviation.
  • Batch normalization: compute mean at group level but standard deviation at the batch level (scale_rewards="batch").
  • No scaling: disable normalization entirely (scale_rewards=False).
3. Estimate KL divergence

Use the Schulman approximator to estimate KL divergence between the policy and a fixed reference model. With beta=0.0 (default), no reference model is loaded.
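The estimator itself is simple; a per-token sketch (log-probabilities would come from the policy and reference models):

```python
import math

def approx_kl(policy_logp, ref_logp):
    # Schulman's k3 estimator of KL(policy || ref) from per-token log-probs.
    # It is always non-negative and has lower variance than the naive estimator.
    diff = ref_logp - policy_logp
    return math.exp(diff) - diff - 1.0
```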
4. Compute loss and update

Maximize advantages while penalizing deviation from the reference policy. The default loss type is "dapo", which normalizes by the number of active tokens in the batch to remove length bias.
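A simplified plain-Python sketch of the clipped surrogate with "dapo" normalization (TRL's implementation is tensorized and adds the KL penalty when beta > 0):

```python
import math

def clipped_loss(logps, old_logps, advantages, mask, eps=0.2):
    # logps, old_logps, mask: one list of per-token values per completion;
    # advantages: one scalar per completion (shared by all its tokens).
    total, active = 0.0, 0
    for lp_seq, old_seq, adv, m_seq in zip(logps, old_logps, advantages, mask):
        for lp, old, m in zip(lp_seq, old_seq, m_seq):
            if not m:
                continue  # masked tokens (padding, truncation) carry no loss
            ratio = math.exp(lp - old)
            clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)
            total += -min(ratio * adv, clipped * adv)
            active += 1
    # "dapo": divide by the number of active tokens in the whole batch,
    # so long and short completions contribute per-token equally.
    return total / active
```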

Dataset format

The dataset must include a "prompt" column. All other columns are passed to reward functions as keyword arguments.
# Standard format
{"prompt": "Solve the equation 2x + 3 = 7.", "ground_truth": "2"}

# Conversational format
{"prompt": [{"role": "user", "content": "Solve the equation 2x + 3 = 7."}],
 "ground_truth": "2"}
For VLM training, include an "image" or "images" column alongside "prompt".

Custom reward functions

A reward function must accept prompts, completions, completion_ids, and any dataset columns as keyword arguments, and return a list of floats (one per completion). Use **kwargs to accept all arguments.
# Reward based on answer correctness
import re

def reward_func(completions, ground_truth, **kwargs):
    matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
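As a quick sanity check, the correctness reward above can be called directly (standard format, where each completion is a plain string; the example completions are illustrative):

```python
import re

def reward_func(completions, ground_truth, **kwargs):
    # Extract the content of \boxed{...} and compare it to the ground truth.
    matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]

rewards = reward_func(
    completions=[r"The answer is \boxed{2}.", "No boxed answer here."],
    ground_truth=["2", "2"],
)
# → [1.0, 0.0]
```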
# Reward based on output format
import re

def format_reward_func(completions, **kwargs):
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
Pass reward functions to the trainer:
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[format_reward_func, reward_func],  # sum of all rewards
    train_dataset=dataset,
)
Reward functions can be async def coroutines. Multiple async functions are executed concurrently via asyncio.gather, so their latency overlaps.
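A sketch of why this helps: two async rewards that each wait 0.1 s (a stand-in for a slow API call or judge model) finish together in about 0.1 s rather than 0.2 s. The functions are illustrative, not part of TRL:

```python
import asyncio
import time

async def length_reward(completions, **kwargs):
    await asyncio.sleep(0.1)  # stand-in for a slow external call
    return [float(len(c)) for c in completions]

async def brevity_reward(completions, **kwargs):
    await asyncio.sleep(0.1)
    return [1.0 / (1 + len(c)) for c in completions]

async def main():
    start = time.perf_counter()
    # gather runs both coroutines concurrently, so their waits overlap
    results = await asyncio.gather(
        length_reward(["abc"]), brevity_reward(["abc"])
    )
    return results, time.perf_counter() - start

results, elapsed = asyncio.run(main())
```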

Multi-task reward functions

Return None for samples that a reward function does not apply to. The trainer ignores None values and sums only valid rewards:
def math_reward_func(completions, task, **kwargs):
    rewards = []
    for completion, t in zip(completions, task):
        if t != "math":
            rewards.append(None)  # not a math sample: ignored by the trainer
        else:
            # is_correct is a placeholder for your own answer checker
            rewards.append(1.0 if is_correct(completion) else 0.0)
    return rewards

Built-in rewards

TRL provides built-in reward functions in trl.rewards, including accuracy_reward for checking mathematical correctness.

Key configuration parameters

num_generations (int, default: 8)
    Number of completions to generate per prompt (the group size G). The effective batch size must be divisible by this value.

max_completion_length (int | None, default: 256)
    Maximum number of tokens to generate per completion.

temperature (float, default: 1.0)
    Sampling temperature. Higher values produce more diverse completions.

top_p (float, default: 1.0)
    Nucleus sampling cutoff. Set below 1.0 to restrict sampling to a smaller token set.

beta (float, default: 0.0)
    KL coefficient controlling deviation from the reference model. When 0.0 (default), the reference model is not loaded. DeepSeek-R1 uses 0.001.

loss_type (str, default: "dapo")
    Loss normalization strategy. Options: "dapo" (normalizes by active tokens in the batch, default), "dr_grpo" (normalizes by max_completion_length), "grpo" (normalizes by sequence length, not recommended), "bnpo", "cispo", "sapo", "luspo", "vespo".

scale_rewards (str | bool, default: "group")
    Reward scaling strategy. "group" (default): normalize within each prompt group. "batch": normalize across the entire batch. False: no scaling.

epsilon (float, default: 0.2)
    Clipping range for the policy ratio in the surrogate objective.

num_iterations (int, default: 1)
    Number of gradient update passes per generated batch (μ in the original paper). When greater than 1, the clipped surrogate objective is used.

mask_truncated_completions (bool, default: False)
    Exclude truncated completions from the loss. Recommended for training stability, especially with long chain-of-thought responses.

reward_weights (list[float] | None, default: None)
    Per-function weights when using multiple reward functions. If None, all functions are weighted equally.

use_vllm (bool, default: False)
    Use vLLM for faster generation. Requires pip install trl[vllm].

vllm_mode (str, default: "colocate")
    How to run vLLM: "colocate" (shares training GPUs) or "server" (separate process on dedicated GPUs).

vllm_gpu_memory_utilization (float, default: 0.3)
    Fraction of GPU memory reserved for vLLM when running in colocate mode.
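Several of these options can be combined in a GRPOConfig; a sketch with illustrative values (assumes trl is installed):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    num_generations=8,
    max_completion_length=512,
    beta=0.0,                         # no reference model, no KL penalty
    loss_type="dapo",
    scale_rewards="group",
    mask_truncated_completions=True,  # recommended for stability
)
```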

Accelerating generation with vLLM

Generation is typically the bottleneck in online RL training. vLLM can provide a significant speedup.
In the default colocate mode, vLLM runs inside the trainer process and shares GPU memory with the training model:
from trl import GRPOConfig

training_args = GRPOConfig(
    use_vllm=True,  # vllm_mode="colocate" by default
)
In server mode, ensure the vLLM server uses different GPUs than the trainer. Use the CUDA_VISIBLE_DEVICES environment variable to separate them, or you may encounter NCCL errors.
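A typical two-terminal setup for server mode might look like the following (the GPU split is illustrative; adjust device indices to your machine):

```shell
# Terminal 1: vLLM server on GPU 0
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2-0.5B-Instruct

# Terminal 2: training on the remaining GPUs
CUDA_VISIBLE_DEVICES=1,2,3 accelerate launch train_grpo.py
```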
By default, Truncated Importance Sampling is applied when using vLLM to correct for the training–inference mismatch between the two engines. Disable it with vllm_importance_sampling_correction=False.

Training with PEFT/LoRA

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
from trl.rewards import accuracy_reward
from peft import LoraConfig

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
    peft_config=LoraConfig(),
)
trainer.train()

Agent training

GRPO supports agentic workflows through tool use. Pass a list of Python functions as tools:
from trl import GRPOTrainer

def multiply(a: int, b: int) -> int:
    """
    Multiplies two integers.

    Args:
        a: The first integer.
        b: The second integer.

    Returns:
        The product of the two integers.
    """
    return a * b

trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=reward_func,
    train_dataset=dataset,
    tools=[multiply],
)
Tools must be Python functions with type-hinted arguments, return types, and a Google-style docstring. The model uses these to determine how to call each tool.

Logged metrics

reward: Overall average reward (sum across functions, weighted by reward_weights)
reward_std: Standard deviation of summed rewards across the batch
completions/mean_length: Average length of generated completions
completions/clipped_ratio: Fraction of completions truncated at max_completion_length
entropy: Average token prediction entropy across completions
kl: Average KL divergence from the reference model (only logged when beta > 0)
clip_ratio/region_mean: Fraction of tokens where the policy ratio was clipped
