Overview
The RL stage fine-tunes an SFT model using reinforcement learning on GSM8K math problems. This improves:

- Mathematical reasoning
- Tool use (Python calculator)
- Answer accuracy
The RL algorithm is deliberately simple:

- No trust region (no KL penalty to a reference model)
- On-policy (no PPO ratio clipping needed)
- Token-level advantage normalization (DAPO style)
- Mean-centering instead of z-score normalization
Quick Start
Single GPU:

Loading Models
Model tag to load from chatsft_checkpoints/. Loads the SFT checkpoint.

Specific checkpoint step to load. If not specified, loads the latest checkpoint.
Training Configuration
Training Horizon
Number of epochs over the GSM8K training set (7,473 problems).
Batch Sizes and Sampling
Max batch size per forward pass. Controls memory usage.
Total examples per optimization step across all ranks. Must be divisible by the number of GPUs. Each GPU processes examples_per_step / world_size examples.

Number of samples to generate per example/question. Total sequences per step = examples_per_step × num_samples.

Generation Parameters
Maximum tokens to generate per sample (answer length limit).
Sampling temperature. 1.0 = standard sampling, 0.0 = greedy.
Top-k sampling. 0 = disabled (sample from full distribution).
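The two sampling knobs can be sketched as follows (a minimal, self-contained illustration of the semantics described above, not the actual generation code):

```python
import math
import random

def sample_token(logits, temperature=1.0, top_k=0):
    """Sample a token id from logits with temperature and optional top-k."""
    if temperature == 0.0:
        # Greedy decoding: pick the highest-logit token
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [l / temperature for l in logits]
    if top_k > 0:
        # Keep only the k highest logits; mask the rest out
        kth = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [s if s >= kth else float("-inf") for s in scaled]
    # Softmax over the (possibly masked) logits, then sample
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    return random.choices(range(len(probs)), weights=probs)[0]

print(sample_token([2.0, 0.5, -1.0], temperature=0.0))  # greedy -> 0
```

Higher temperature flattens the distribution (more exploration); top_k=0 leaves the full distribution intact.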
Optimization
Learning rate for transformer matrices (Muon optimizer).
Learning rate for input embedding (Adam).
Learning rate for output unembedding (Adam).
Weight decay for Adam optimizer. Default is 0 (no regularization).
Initial learning rate as a fraction of the base LR. Example: matrix_lr × init_lr_frac = 0.02 × 0.05 = 0.001 (starting LR).

Learning Rate Schedule
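As a sketch, the effective LR can be thought of as the base LR scaled by init_lr_frac and by a multiplier that ramps linearly from 1.0 to 0.0 over training (this exact composition is an assumption, not the training code itself):

```python
def lr_multiplier(step, num_steps):
    """Linear rampdown of the LR multiplier from 1.0 to 0.0 over training."""
    return 1.0 - step / num_steps

def effective_lr(step, num_steps, base_lr=0.02, init_lr_frac=0.05):
    # base_lr and init_lr_frac match the worked example above
    return base_lr * init_lr_frac * lr_multiplier(step, num_steps)

print(effective_lr(0, 100))    # starting LR (≈ 0.001)
print(effective_lr(100, 100))  # decays to 0.0 at the end
```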
Simple linear rampdown to zero.

Evaluation and Checkpointing
Evaluate pass@k on validation set every N steps.
Number of validation examples for pass@k evaluation.
Save checkpoint every N steps.
Logging
Wandb run name. Set to "dummy" to disable wandb logging.
Device type: cuda, cpu, or mps. Empty string = autodetect.
Precision: float32 or bfloat16.
RL Algorithm Details
Rollout Phase
For each training example:

- Render prompt: Tokenize the conversation, delete the assistant's answer, keep <|assistant_start|> to prime the completion
- Sample completions: Generate num_samples completions using batched generation
- Calculate rewards: Check each completion against ground truth (1.0 if correct, 0.0 if wrong)
- Compute advantages: Mean-center rewards:
advantage = reward - mean(rewards)
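Concretely, for the num_samples completions of a single question (reward values are illustrative):

```python
def compute_advantages(rewards):
    """Mean-center rewards within the group of samples for one question.

    Note: no division by the standard deviation, i.e. mean-centering
    rather than full z-score normalization.
    """
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

# 4 samples for one question: two correct, two wrong
print(compute_advantages([1.0, 0.0, 0.0, 1.0]))  # [0.5, -0.5, -0.5, 0.5]
```

If all samples get the same reward, every advantage is zero and the question contributes no gradient.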
Training Phase
Policy gradient objective:

- No KL penalty: No reference model, no trust region
- No PPO clipping: On-policy, so no importance sampling ratio
- Token-level advantages: Advantages broadcast to all tokens (DAPO style)
- Mean centering only: No division by standard deviation
Masking
The loss only trains on:

- Generated tokens (not prompt tokens)
- Non-tool tokens (not forced Python calculator invocations)

Masked-out positions are assigned target = -1 (ignore_index).
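A sketch of how the per-token targets might be built (token ids and mask layout are illustrative; -1 marks positions excluded from the loss):

```python
IGNORE_INDEX = -1

def build_targets(tokens, is_prompt, is_tool):
    """Next-token targets with prompt and tool tokens masked to -1.

    tokens[t+1] is the target for position t; positions whose target
    token belongs to the prompt or to a tool span are ignored.
    """
    targets = []
    for t in range(len(tokens) - 1):
        if is_prompt[t + 1] or is_tool[t + 1]:
            targets.append(IGNORE_INDEX)
        else:
            targets.append(tokens[t + 1])
    return targets

# toy sequence: 3 prompt tokens, then generated tokens incl. one tool span
tokens    = [10, 11, 12, 20, 21, 30, 31, 22]
is_prompt = [1, 1, 1, 0, 0, 0, 0, 0]
is_tool   = [0, 0, 0, 0, 0, 1, 1, 0]
print(build_targets(tokens, is_prompt, is_tool))
# [-1, -1, 20, 21, -1, -1, 22]
```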
GSM8K Task
GSM8K is a dataset of grade school math word problems. The task:

- Read the question (e.g., "Janet has 16 apples. She eats 3. How many remain?")
- Generate a solution (with optional Python tool use)
- Extract the final numerical answer
- Reward = 1.0 if answer matches ground truth, else 0.0
Validation set: 1,319 problems
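A minimal sketch of the reward check (the real grader may differ; this assumes the final number in the completion is taken as the answer):

```python
import re

def reward(completion, ground_truth):
    """1.0 if the last number in the completion matches ground truth, else 0.0."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", completion.replace(",", ""))
    if not numbers:
        return 0.0
    return 1.0 if float(numbers[-1]) == float(ground_truth) else 0.0

print(reward("16 - 3 = 13, so 13 apples remain.", "13"))  # 1.0
print(reward("She has 12 apples left.", "13"))            # 0.0
```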
Example Workflows
Basic RL training
RL with more samples per question
RL with higher temperature
Multiple epochs
Conservative learning rate
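These workflows might look like the following (the script path and several flag names here are assumptions based on the options documented above; check the actual script for the exact interface):

```shell
# Basic RL training (single GPU)
python -m scripts.chat_rl

# More samples per question (better advantage estimates)
python -m scripts.chat_rl --num-samples 32

# Higher temperature for more diverse rollouts
python -m scripts.chat_rl --temperature 1.2

# Multiple epochs over the GSM8K training set
python -m scripts.chat_rl --num-epochs 2

# Conservative learning rate via a smaller initial LR fraction
python -m scripts.chat_rl --init-lr-frac 0.02
```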
Output
Checkpoints are saved to $NANOCHAT_BASE_DIR/chatrl_checkpoints/{model_tag}/:

- step_{N}_model.pt - Model weights
- step_{N}_meta.json - Metadata
Monitoring
Key metrics logged to console and wandb:

Rollout Metrics
- reward - Average reward across sampled completions (0 to 1)
- sequence_length - Average length of generated completions
Evaluation Metrics
- pass@k - Success rate with k samples (k=1,2,…,device_batch_size)
  - pass@1 = greedy accuracy
  - pass@k = at least one correct answer in k samples
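Pass@k can be computed from per-sample correctness as in this sketch (using the simple "any of the first k samples correct" definition; the actual evaluation code may differ):

```python
def pass_at_k(correctness, k):
    """Fraction of problems with at least one correct answer among the
    first k samples. `correctness` holds one list of booleans per
    problem, one entry per sampled completion."""
    hits = sum(1 for samples in correctness if any(samples[:k]))
    return hits / len(correctness)

# 3 problems, 4 samples each (True = correct completion)
results = [
    [False, True, False, False],   # solved on the 2nd sample
    [True, True, False, True],     # solved on the 1st sample
    [False, False, False, False],  # never solved
]
print(pass_at_k(results, 1))  # 1/3
print(pass_at_k(results, 4))  # 2/3
```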
Training Metrics
- lrm - Learning rate multiplier (1.0 → 0.0 over training)
Pass@k Evaluation
Pass@k measures the probability of generating at least one correct answer in k attempts:

- After SFT: pass@1 ≈ 20-30%, pass@16 ≈ 40-50%
- After RL: pass@1 ≈ 30-40%, pass@16 ≈ 50-65%
Performance Tips
- Memory management: Reduce --device-batch-size if OOM during generation
- Exploration: Increase --temperature for more diverse rollouts
- Sample efficiency: Increase --num-samples for better advantage estimates
- Throughput: Batch size should be divisible by world size
- Evaluation cost: Reduce --eval-examples for faster evals
Tool Use
GSM8K solutions can invoke the Python calculator tool:

- Does NOT train on tool invocation tokens (masked out)
- Does NOT train on tool output tokens (masked out)
- Only trains on text reasoning and final answer