Reinforcement learning script that improves a chat model’s math reasoning using policy gradients on GSM8K.

Usage

# Single GPU
python -m scripts.chat_rl

# Distributed (8 GPUs)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default

Algorithm

Simplified GRPO (Group Relative Policy Optimization):
  1. No trust region (no KL regularization to reference model)
  2. On-policy (no PPO ratio+clip)
  3. DAPO-style token-level normalization
  4. Advantage = reward - mean (not z-score)
This reduces to a simple REINFORCE-style algorithm.
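With all four simplifications, the per-group loss is just negative advantage-weighted log-probability, normalized by the total number of generated tokens in the group. A minimal sketch of that objective (the function name and tensor layout are illustrative, not the script's actual code):

```python
import torch

def reinforce_loss(logprobs, mask, rewards):
    """Simplified GRPO: REINFORCE with a group-mean baseline and
    DAPO-style token-level normalization.

    logprobs: (G, T) log-probs of sampled tokens for G samples of one question
    mask:     (G, T) 1.0 on generated response tokens, 0.0 on prompt/padding
    rewards:  (G,)   scalar reward per sample
    """
    # Advantage = reward - group mean (no division by std, i.e. no z-score)
    advantages = rewards - rewards.mean()
    # Weight each token's log-prob by its sample's advantage
    weighted = -(advantages[:, None] * logprobs * mask)
    # Token-level normalization: divide by total generated tokens in the
    # whole group rather than per-sequence length (DAPO-style)
    return weighted.sum() / mask.sum()
```

Because there is no reference model and no PPO ratio, a single backward pass through this loss per sampled group is the whole algorithm.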

Parameters

Logging

--run (str, default: "dummy")
    Weights & Biases run name. Use 'dummy' to disable wandb logging.

Runtime

--device-type (str, default: "")
    Device type: cuda, cpu, or mps. Empty string enables autodetection.
--dtype (str, default: "bfloat16")
    Floating point precision: float32 or bfloat16.

Model Loading

--model-tag (str, default: None)
    Model tag to load from SFT checkpoints (e.g. d24).
--model-step (int, default: None)
    Model step to load. If not specified, loads the last checkpoint.

Training Horizon

--num-epochs (int, default: 1)
    Number of epochs over the GSM8K training set.

Batch Sizes / Sampling

--device-batch-size (int, default: 8)
    Maximum batch size per forward pass.
--examples-per-step (int, default: 16)
    Total examples per optimization step across all ranks.
--num-samples (int, default: 16)
    Number of samples per example/question.
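These three flags together determine how much work each rank does per optimization step. A sketch of the bookkeeping (assuming, as the defaults suggest, that examples divide evenly across ranks and rollouts divide evenly into device batches):

```python
def rollouts_per_step(examples_per_step, num_samples, device_batch_size, world_size):
    """Illustrative helper: how many completions each rank generates per
    step, and how many gradient-accumulation micro-batches that implies."""
    # Each rank handles an equal share of the examples per step
    assert examples_per_step % world_size == 0
    examples_per_rank = examples_per_step // world_size
    # Every example is sampled num_samples times
    rollouts_per_rank = examples_per_rank * num_samples
    # Forward passes are capped at device_batch_size completions each
    assert rollouts_per_rank % device_batch_size == 0
    grad_accum_steps = rollouts_per_rank // device_batch_size
    return rollouts_per_rank, grad_accum_steps
```

With the defaults on 8 GPUs (16 examples/step, 16 samples, device batch 8), each rank generates 32 completions per step in 4 micro-batches.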

Generation

--max-new-tokens (int, default: 256)
    Maximum tokens to generate per sample.
--temperature (float, default: 1.0)
    Sampling temperature for generation.
--top-k (int, default: 50)
    Top-k sampling. 0 = disabled.
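Temperature scales the logits before softmax, and top-k then restricts sampling to the k highest-probability tokens. A minimal sketch of how one next token could be drawn under these two settings (not the script's actual generation loop):

```python
import torch

def sample_next_token(logits, temperature=1.0, top_k=50):
    """Sample one token id from a (vocab_size,) logits tensor with
    temperature scaling and optional top-k truncation (top_k=0 disables)."""
    logits = logits / temperature  # makes a copy; the caller's tensor is untouched
    if top_k > 0:
        # Mask out everything below the k-th largest logit
        topv, _ = torch.topk(logits, min(top_k, logits.numel()))
        logits[logits < topv[-1]] = float("-inf")
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()
```

Higher temperature flattens the distribution (more exploration); top_k=0 leaves the full vocabulary available at every step.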

Optimization

--embedding-lr (float, default: 0.2)
    Learning rate for embedding parameters (Adam).
--unembedding-lr (float, default: 0.004)
    Learning rate for unembedding parameters (Adam).
--matrix-lr (float, default: 0.02)
    Learning rate for matrix parameters (Muon).
--weight-decay (float, default: 0.0)
    Weight decay for embedding/unembedding parameters (Adam).
--init-lr-frac (float, default: 0.05)
    Initial learning rate as a fraction of the base learning rate.

Evaluation / Checkpointing

--eval-every (int, default: 60)
    Evaluate pass@k every N steps.
--eval-examples (int, default: 400)
    Number of examples for pass@k evaluation.
--save-every (int, default: 60)
    Save checkpoint every N steps.

Examples

Basic RL Training

torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl \
  --run=my-rl-run \
  --model-tag=d24

Multiple Epochs

torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl \
  --model-tag=d24 \
  --num-epochs=3

More Samples per Question

torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl \
  --model-tag=d24 \
  --num-samples=32 \
  --examples-per-step=8

Higher Temperature Exploration

torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl \
  --model-tag=d24 \
  --temperature=1.5 \
  --top-k=100

Reward Function

The reward is calculated by:
  1. Generating completions for GSM8K problems
  2. Extracting the final numerical answer from the completion
  3. Comparing with the ground truth answer
  4. Binary reward: 1.0 for correct, 0.0 for incorrect
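The four steps above can be sketched as a single reward function. The extraction regex here is an illustrative stand-in for the script's actual answer parsing (GSM8K ground-truth answers are integers, possibly comma-grouped):

```python
import re

def gsm8k_reward(completion, ground_truth):
    """Binary reward: 1.0 if the last number in the completion matches
    the ground-truth answer, else 0.0."""
    # Find all integers/decimals, allowing comma grouping like 1,234
    numbers = re.findall(r"-?\d[\d,]*(?:\.\d+)?", completion)
    if not numbers:
        return 0.0  # no numerical answer to extract
    predicted = numbers[-1].replace(",", "")
    return 1.0 if predicted == str(ground_truth) else 0.0
```

Because the reward is binary, the group-mean baseline in the advantage computation is simply the fraction of correct samples for that question.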

Pass@k Evaluation

During evaluation, the script computes pass@k for k=1 to device_batch_size:
  • pass@1: at least one correct answer in the first sample
  • pass@k: at least one correct answer among the first k samples
This measures the model’s ability to solve problems with multiple attempts.
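For a single question, the per-k statistic can be computed in one pass over the samples; a hypothetical helper (the script would then average these vectors over the eval_examples questions):

```python
def pass_at_k(correct_flags):
    """Given per-sample correctness for one question (list of bools, in
    generation order), return pass@k for k = 1..len(correct_flags):
    1.0 if any of the first k samples is correct, else 0.0."""
    results = []
    any_correct = False
    for flag in correct_flags:
        any_correct = any_correct or flag  # running "solved yet?" state
        results.append(1.0 if any_correct else 0.0)
    return results
```

Note that pass@k is monotonically non-decreasing in k: once any sample is correct, every larger k also counts as a pass.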
