Usage
Algorithm
Simplified GRPO (Group Relative Policy Optimization):
- No trust region (no KL regularization to a reference model)
- On-policy (no PPO ratio/clip)
- DAPO-style token-level loss normalization
- Advantage = reward - mean (no z-score normalization)
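The simplifications above can be sketched in a few lines. This is a minimal, framework-free illustration (the function names and list-based inputs are my own, not the script's API): advantages are mean-centered per group without dividing by the standard deviation, and the loss is normalized by the total token count across the group (DAPO-style) rather than per sequence.

```python
def grpo_advantages(rewards):
    """Advantage = reward - group mean (no z-score / std normalization)."""
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]

def grpo_loss(token_logprobs, advantages):
    """On-policy policy-gradient loss with DAPO-style token-level
    normalization: sum -logprob * advantage over every sampled token in
    the group, then divide by the total token count (not a per-sequence
    mean). token_logprobs is a list of per-sample log-prob lists."""
    total, n_tokens = 0.0, 0
    for lps, adv in zip(token_logprobs, advantages):
        for lp in lps:
            total += -lp * adv
            n_tokens += 1
    return total / max(n_tokens, 1)
```

Because there is no PPO ratio/clip and no KL term, each optimization step uses only the log-probabilities of the freshly sampled tokens.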
Parameters
Logging
Weights & Biases run name. Use 'dummy' to disable wandb logging.
Runtime
Device type: cuda, cpu, or mps. Empty string enables autodetection.
Floating point precision: float32 or bfloat16.
Model Loading
Model tag to load from SFT checkpoints (e.g. d24).
Model step to load. If not specified, loads the last checkpoint.
Training Horizon
Number of epochs over GSM8K training set.
Batch Sizes / Sampling
Maximum batch size per forward pass.
Total examples per optimization step across all ranks.
Number of samples per example/question.
Generation
Maximum tokens to generate per sample.
Sampling temperature for generation.
Top-k sampling.
0 = disabled.
Optimization
Learning rate for embedding parameters (Adam).
Learning rate for unembedding parameters (Adam).
Learning rate for matrix parameters (Muon).
Weight decay for embedding/unembedding parameters (Adam).
Initial learning rate as fraction of base learning rate.
Evaluation / Checkpointing
Evaluate pass@k every N steps.
Number of examples for pass@k evaluation.
Save checkpoint every N steps.
Examples
Basic RL Training
Multiple Epochs
More Samples per Question
Higher Temperature Exploration
Reward Function
The reward is calculated by:
- Generating completions for GSM8K problems
- Extracting the final numerical answer from the completion
- Comparing with the ground truth answer
- Binary reward: 1.0 for correct, 0.0 for incorrect
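The steps above can be sketched as follows. This is a simplified illustration, not the script's actual parser: the helper names are hypothetical, and the regex simply takes the last number-like token in the completion (stripping thousands separators) as the final answer.

```python
import re

def extract_answer(completion):
    """Extract the final numerical answer from a completion by taking
    the last number-like token (illustrative heuristic)."""
    matches = re.findall(r"-?\d[\d,]*(?:\.\d+)?", completion)
    if not matches:
        return None
    return matches[-1].replace(",", "")  # drop thousands separators

def reward(completion, ground_truth):
    """Binary reward: 1.0 if the extracted answer matches the ground
    truth exactly, 0.0 otherwise (including when no number is found)."""
    pred = extract_answer(completion)
    return 1.0 if pred is not None and pred == ground_truth.strip() else 0.0
```

A completion with no numerical answer at all simply receives a reward of 0.0.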
Pass@k Evaluation
During evaluation, the script computes pass@k for k=1 to device_batch_size:
- pass@1: At least one correct answer in top 1 sample
- pass@k: At least one correct answer in top k samples
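For a single question, this definition of pass@k can be computed directly from the per-sample correctness flags. A minimal sketch (the function name and dict-valued return are illustrative, not the script's API), assuming the k samples are the first k generated:

```python
def pass_at_k(correct_flags, max_k=None):
    """pass@k for one question: 1.0 if any of the first k samples is
    correct, else 0.0, reported for every k = 1..max_k."""
    max_k = max_k or len(correct_flags)
    result, any_correct = {}, False
    for k in range(1, max_k + 1):
        any_correct = any_correct or bool(correct_flags[k - 1])
        result[k] = 1.0 if any_correct else 0.0
    return result
```

Averaging these per-question values over the evaluation set gives the reported pass@k curve; it is non-decreasing in k by construction.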