Overview

The RL stage fine-tunes an SFT model using reinforcement learning on GSM8K math problems. This improves:
  • Mathematical reasoning
  • Tool use (Python calculator)
  • Answer accuracy
The implementation is a simplified variant of GRPO/REINFORCE:
  1. No trust region (no KL penalty to reference model)
  2. On-policy (no PPO ratio clipping needed)
  3. Token-level advantage normalization (DAPO style)
  4. Mean-centering instead of z-score normalization

Quick Start

Single GPU:
python -m scripts.chat_rl
Multi-GPU (8 GPUs):
torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl \
  --run=default

Loading Models

--model-tag (str, default: None)
  Model tag to load from chatsft_checkpoints/. Loads the SFT checkpoint.
--model-step (int, default: None)
  Specific checkpoint step to load. If not specified, loads the latest checkpoint.

Training Configuration

Training Horizon

--num-epochs (int, default: 1)
  Number of epochs over the GSM8K training set (7,473 problems).

Batch Sizes and Sampling

--device-batch-size (int, default: 8)
  Max batch size per forward pass. Controls memory usage.
--examples-per-step (int, default: 16)
  Total examples per optimization step across all ranks. Must be divisible by the number of GPUs; each GPU processes examples_per_step / world_size examples.
--num-samples (int, default: 16)
  Number of samples to generate per example/question. Total sequences per step = examples_per_step × num_samples.
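As a quick sanity check, the per-step arithmetic above can be sketched with the default flag values, assuming an 8-GPU run (numbers are illustrative):

```python
# Default flag values; world_size is assumed to be 8 GPUs.
examples_per_step = 16
num_samples = 16
world_size = 8

assert examples_per_step % world_size == 0  # required: even split across ranks
examples_per_rank = examples_per_step // world_size  # 2 examples per GPU
total_sequences = examples_per_step * num_samples    # 256 sequences per step
```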

Generation Parameters

--max-new-tokens (int, default: 256)
  Maximum tokens to generate per sample (answer length limit).
--temperature (float, default: 1.0)
  Sampling temperature. 1.0 = standard sampling, 0.0 = greedy.
--top-k (int, default: 50)
  Top-k sampling. 0 = disabled (sample from the full distribution).

Optimization

--matrix-lr (float, default: 0.02)
  Learning rate for transformer matrices (Muon optimizer).
--embedding-lr (float, default: 0.2)
  Learning rate for the input embedding (Adam).
--unembedding-lr (float, default: 0.004)
  Learning rate for the output unembedding (Adam).
--weight-decay (float, default: 0.0)
  Weight decay for the Adam optimizer. Default is 0 (no regularization).
--init-lr-frac (float, default: 0.05)
  Initial learning rate as a fraction of the base LR. Example: matrix_lr × init_lr_frac = 0.02 × 0.05 = 0.001 (starting LR).

Learning Rate Schedule

Simple linear rampdown to zero:
lr_multiplier = 1.0 - (step / num_steps)
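The rampdown can be written as a plain function of the step index (a minimal sketch; how it composes with init_lr_frac is not shown here):

```python
# Linear rampdown: multiplier goes from 1.0 at step 0 to 0.0 at the final step.
def lr_multiplier(step: int, num_steps: int) -> float:
    return 1.0 - step / num_steps

# The effective LR at step t is base_lr scaled by lr_multiplier(t, num_steps).
```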

Evaluation and Checkpointing

--eval-every (int, default: 60)
  Evaluate pass@k on the validation set every N steps.
--eval-examples (int, default: 400)
  Number of validation examples for pass@k evaluation.
--save-every (int, default: 60)
  Save a checkpoint every N steps.

Logging

--run (str, default: "dummy")
  Wandb run name. Set to "dummy" to disable wandb logging.
--device-type (str, default: "")
  Device type: cuda, cpu, or mps. Empty string = autodetect.
--dtype (str, default: "bfloat16")
  Precision: float32 or bfloat16.

RL Algorithm Details

Rollout Phase

For each training example:
  1. Render prompt: Tokenize conversation, delete assistant’s answer, keep <|assistant_start|> to prime completion
  2. Sample completions: Generate num_samples completions using batched generation
  3. Calculate rewards: Check each completion against ground truth (1.0 if correct, 0.0 if wrong)
  4. Compute advantages: Mean-center rewards: advantage = reward - mean(rewards)
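Steps 3 and 4 can be sketched for a single question, assuming four sampled completions with hypothetical correctness results:

```python
# Hypothetical rewards for 4 samples: 1.0 if correct, 0.0 if wrong.
rewards = [1.0, 0.0, 0.0, 1.0]
mean_reward = sum(rewards) / len(rewards)        # 0.5
advantages = [r - mean_reward for r in rewards]  # mean-centering only
# Correct samples get +0.5, wrong ones -0.5; no division by std.
```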

Training Phase

Policy gradient objective:
# Calculate log probabilities for generated tokens
logp = -model(inputs, targets, loss_reduction='none')  # (B, T)

# Weight by advantages (token-level broadcasting)
pg_objective = (logp * advantages.unsqueeze(-1)).sum()

# Normalize by valid tokens
pg_objective = pg_objective / num_valid_tokens

# Minimize negative objective
loss = -pg_objective
loss.backward()
Key differences from standard GRPO:
  • No KL penalty: No reference model, no trust region
  • No PPO clipping: On-policy, so no importance sampling ratio
  • Token-level advantages: Advantages broadcast to all tokens (DAPO style)
  • Mean centering only: No division by standard deviation

Masking

The loss only trains on:
  • Generated tokens (not prompt tokens)
  • Non-tool tokens (not forced Python calculator invocations)
Prompt and tool tokens have target = -1 (ignore_index).
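The masking rule above can be sketched as follows; the token-id lists, prompt_len, and the (start, end) tool-span format are illustrative assumptions, not the script's actual data layout:

```python
IGNORE = -1  # ignore_index: the loss skips these positions

def build_targets(tokens, prompt_len, tool_spans):
    targets = tokens[1:] + [IGNORE]           # next-token prediction targets
    for t in range(min(prompt_len, len(targets))):
        targets[t] = IGNORE                   # never train on prompt tokens
    for start, end in tool_spans:             # half-open spans of tool tokens
        for t in range(start, min(end, len(targets))):
            targets[t] = IGNORE               # never train on tool outputs
    return targets
```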

GSM8K Task

GSM8K is a dataset of grade school math word problems. The task:
  1. Read the question (e.g., “Janet has 16 apples. She eats 3. How many remain?”)
  2. Generate a solution (with optional Python tool use)
  3. Extract the final numerical answer
  4. Reward = 1.0 if answer matches ground truth, else 0.0
Training set: 7,473 problems
Validation set: 1,319 problems
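A hedged sketch of the reward check in steps 3-4: take the last number in the completion as the final answer (the script's actual extraction may differ, e.g. it may use a dedicated answer delimiter):

```python
import re

def reward(completion: str, ground_truth: str) -> float:
    # Strip thousands separators, then grab every number in the text.
    numbers = re.findall(r"-?\d+(?:\.\d+)?", completion.replace(",", ""))
    # Binary reward: last number must exactly match the ground truth.
    return 1.0 if numbers and numbers[-1] == ground_truth else 0.0

reward("16 - 3 = 13, so 13 apples remain.", "13")  # 1.0
```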

Example Workflows

Basic RL training

torchrun --nproc_per_node=8 -m scripts.chat_rl \
  --model-tag=d12 \
  --run=rl_d12

RL with more samples per question

torchrun --nproc_per_node=8 -m scripts.chat_rl \
  --model-tag=d20 \
  --num-samples=32 \
  --examples-per-step=8 \
  --run=rl_d20_32samples
Total sequences per step = 8 × 32 = 256.

RL with higher temperature

torchrun --nproc_per_node=8 -m scripts.chat_rl \
  --model-tag=d16 \
  --temperature=1.2 \
  --run=rl_d16_temp12
Higher temperature = more exploration.

Multiple epochs

torchrun --nproc_per_node=8 -m scripts.chat_rl \
  --model-tag=d20 \
  --num-epochs=3 \
  --run=rl_d20_3epochs

Conservative learning rate

torchrun --nproc_per_node=8 -m scripts.chat_rl \
  --model-tag=d12 \
  --init-lr-frac=0.02 \
  --run=rl_d12_conservative

Output

Checkpoints are saved to $NANOCHAT_BASE_DIR/chatrl_checkpoints/{model_tag}/:
  • step_{N}_model.pt - Model weights
  • step_{N}_meta.json - Metadata
Note: Optimizer state is NOT saved (not needed for RL).

Monitoring

Key metrics logged to console and wandb:

Rollout Metrics

  • reward - Average reward across sampled completions (0 to 1)
  • sequence_length - Average length of generated completions

Evaluation Metrics

  • pass@k - Success rate with k samples (k=1,2,…,device_batch_size)
    • pass@1 = greedy accuracy
    • pass@k = at least one correct answer in k samples

Training Metrics

  • lrm - Learning rate multiplier (1.0 → 0.0 over training)

Pass@k Evaluation

Pass@k measures the probability of generating at least one correct answer in k attempts:
pass@k = P(any of k samples is correct)
Evaluated at k=1,2,…,device_batch_size on validation set. Typical progression:
  • After SFT: pass@1 ≈ 20-30%, pass@16 ≈ 40-50%
  • After RL: pass@1 ≈ 30-40%, pass@16 ≈ 50-65%
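The empirical estimate can be sketched as follows, assuming a list of per-question boolean correctness flags (one flag per sampled completion; names are illustrative):

```python
def pass_at_k(correct_flags, k):
    # Fraction of questions with at least one correct answer in the first k samples.
    hits = sum(any(flags[:k]) for flags in correct_flags)
    return hits / len(correct_flags)

flags = [[False, True, False], [False, False, False]]
# pass_at_k(flags, 1) == 0.0, pass_at_k(flags, 2) == 0.5
```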

Performance Tips

  1. Memory management: Reduce --device-batch-size if OOM during generation
  2. Exploration: Increase --temperature for more diverse rollouts
  3. Sample efficiency: Increase --num-samples for better advantage estimates
  4. Throughput: Batch size should be divisible by world size
  5. Evaluation cost: Reduce --eval-examples for faster evals

Tool Use

GSM8K solutions can invoke the Python calculator tool:
<|assistant_start|>Let's calculate:
<|python_start|>16 - 3<|python_end|>
<|output_start|>13<|output_end|>
Janet has 13 apples remaining.<|assistant_end|>
The RL training:
  • Does NOT train on tool invocation tokens (masked out)
  • Does NOT train on tool output tokens (masked out)
  • Only trains on text reasoning and final answer
This ensures the model learns when to use tools, not how to compute (Python does that).
