## Overview

LoRA (Low-Rank Adaptation) enables efficient fine-tuning of large language models by training only a small number of additional parameters. This guide shows you how to use LoRA with rLLM for reinforcement learning.
## Why LoRA?

**Benefits:**

- Reduced memory footprint (2-10x less GPU memory)
- Faster training (2-3x speedup)
- Smaller checkpoint sizes
- Multiple task-specific adapters from one base model

**Trade-offs:**

- Slightly lower performance than full fine-tuning
- Additional hyperparameters to tune (rank, alpha)
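The parameter savings come directly from the low-rank factorization. A minimal sketch (the 4096x4096 projection below is an illustrative size for a 7B-class model, not a measurement from rLLM):

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    # LoRA replaces the update to a d_out x d_in weight matrix with B @ A,
    # where B is d_out x rank and A is rank x d_in.
    return rank * (d_in + d_out)

full_params = 4096 * 4096                            # one attention projection
lora_params = lora_trainable_params(4096, 4096, 16)  # same layer at rank=16
print(full_params // lora_params)                    # prints 128
```

At rank 16, this single layer trains 128x fewer parameters than a full update, which is where the memory and checkpoint-size savings listed above come from.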
## Quick Start

### Installation

LoRA support is included in rLLM by default via PEFT:

```bash
pip install rllm  # Includes PEFT dependency
```
### Basic Example

Here's a complete example using LoRA for math problem solving:

```python
import hydra

from rllm.agents.math_agent import MathAgent
from rllm.data.dataset import DatasetRegistry
from rllm.environments.base.single_turn_env import SingleTurnEnvironment
from rllm.rewards.reward_fn import math_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


@hydra.main(
    config_path="pkg://rllm.trainer.config",
    config_name="agent_ppo_trainer",
    version_base=None,
)
def main(config):
    # Load datasets
    train_dataset = DatasetRegistry.load_dataset("gsm8k", "train")
    test_dataset = DatasetRegistry.load_dataset("gsm8k", "test")

    # Configure LoRA
    config.actor_rollout_ref.actor.model.lora_rank = 16
    config.actor_rollout_ref.actor.model.lora_alpha = 32
    config.actor_rollout_ref.actor.model.enable_lora = True

    # Set up the environment
    env_args = {"reward_fn": math_reward_fn}

    # Create the trainer
    trainer = AgentTrainer(
        agent_class=MathAgent,
        agent_args={},
        env_class=SingleTurnEnvironment,
        env_args=env_args,
        config=config,
        train_dataset=train_dataset,
        val_dataset=test_dataset,
    )
    trainer.train()


if __name__ == "__main__":
    main()
```

Adapted from examples/gsm8k_lora/train_gsm8k_with_lora.py:10.
## Configuration

### Enable LoRA

Set the LoRA flag in your config:

```python
config.actor_rollout_ref.actor.model.enable_lora = True
```

Or via the command line:

```bash
python train.py actor_rollout_ref.actor.model.enable_lora=True
```

### Set LoRA Rank

The rank controls the number of trainable parameters:

```python
config.actor_rollout_ref.actor.model.lora_rank = 16  # Common values: 4, 8, 16, 32
```

- Lower rank (4-8): fewer parameters, faster, less expressive
- Higher rank (16-32): more parameters, slower, more expressive

Start with rank=16 and adjust based on performance.

### Set LoRA Alpha

Alpha controls the scaling of LoRA updates:

```python
config.actor_rollout_ref.actor.model.lora_alpha = 32
```

Common pattern: alpha = 2 * rank. Higher alpha means stronger LoRA influence.

### Configure Target Modules (Optional)

Specify which layers LoRA is applied to:

```python
config.actor_rollout_ref.actor.model.lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
]
```

Default: applies to all attention layers. Add "gate_proj", "up_proj", and "down_proj" to cover the MLP layers as well.

### Set Dropout (Optional)

Add regularization:

```python
config.actor_rollout_ref.actor.model.lora_dropout = 0.05
```

Default: 0.0 (no dropout). Use 0.05-0.1 if you see overfitting.
### Complete Configuration Example

```yaml
actor_rollout_ref:
  actor:
    model:
      path: meta-llama/Llama-3.1-8B-Instruct
      enable_lora: true
      lora_rank: 16
      lora_alpha: 32
      lora_dropout: 0.05
      lora_target_modules:
        - q_proj
        - k_proj
        - v_proj
        - o_proj
    ppo_mini_batch_size: 256
    ppo_micro_batch_size: 64
    optim:
      lr: 5e-5  # A higher LR works well with LoRA
```
## Hyperparameter Selection

### LoRA Rank

| Rank | Parameters | Use Case |
|---|---|---|
| 4 | ~1M | Simple tasks, very limited GPU memory |
| 8 | ~2M | Most tasks, good balance |
| 16 | ~4M | Complex tasks, recommended default |
| 32 | ~8M | Very complex tasks, approaching full fine-tuning |
| 64 | ~16M | Rarely needed, diminishing returns |
### LoRA Alpha

Rule of thumb: alpha = 2 * rank

```python
# Conservative (more stable)
config.actor_rollout_ref.actor.model.lora_rank = 16
config.actor_rollout_ref.actor.model.lora_alpha = 16

# Balanced (recommended)
config.actor_rollout_ref.actor.model.lora_rank = 16
config.actor_rollout_ref.actor.model.lora_alpha = 32

# Aggressive (faster adaptation)
config.actor_rollout_ref.actor.model.lora_rank = 16
config.actor_rollout_ref.actor.model.lora_alpha = 64
```
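What these three settings actually change is the multiplier PEFT applies to the low-rank update: the adapted weight is W + (alpha / rank) * B @ A, so alpha only matters relative to rank. A quick sketch:

```python
def lora_scale(lora_alpha: int, lora_rank: int) -> float:
    # PEFT scales the low-rank update B @ A by alpha / rank,
    # so at fixed rank, doubling alpha doubles the update's magnitude.
    return lora_alpha / lora_rank

print(lora_scale(16, 16))  # conservative -> 1.0
print(lora_scale(32, 16))  # balanced    -> 2.0
print(lora_scale(64, 16))  # aggressive  -> 4.0
```

This is also why the alpha = 2 * rank rule transfers across ranks: it keeps the effective scale fixed at 2.0 whether you train at rank 8 or rank 32.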
### Learning Rate

LoRA works well with higher learning rates than full fine-tuning:

```python
# Full fine-tuning
config.actor_rollout_ref.actor.optim.lr = 1e-6

# LoRA fine-tuning
config.actor_rollout_ref.actor.optim.lr = 5e-5  # 50x higher!
```

Start with lr=5e-5 for LoRA and reduce it if training is unstable.
## Memory Savings

LoRA dramatically reduces memory requirements:

| Model Size | Full Fine-Tuning | LoRA (rank=16) | Memory Savings |
|---|---|---|---|
| 7B | ~28 GB | ~12 GB | 57% |
| 13B | ~52 GB | ~20 GB | 62% |
| 70B | ~280 GB | ~100 GB | 64% |

These are approximate values for single-GPU training with batch size 1. Actual memory usage depends on batch size, sequence length, and other factors.
## Combining LoRA with Other Techniques

### LoRA + Gradient Checkpointing

Maximize memory savings:

```python
config.actor_rollout_ref.actor.model.enable_lora = True
config.actor_rollout_ref.actor.model.lora_rank = 16
config.critic.model.enable_gradient_checkpointing = True
```

### LoRA + Quantization (QLoRA)

Train even larger models:

```python
config.actor_rollout_ref.actor.model.enable_lora = True
config.actor_rollout_ref.actor.model.lora_rank = 16
config.actor_rollout_ref.actor.model.load_in_4bit = True
config.actor_rollout_ref.actor.model.bnb_4bit_compute_dtype = "bfloat16"
```

### LoRA + Multi-GPU

Scale across GPUs:

```python
config.actor_rollout_ref.actor.model.enable_lora = True
config.actor_rollout_ref.actor.model.lora_rank = 16
config.actor_rollout_ref.actor.ppo_mini_batch_size = 256
config.actor_rollout_ref.actor.ppo_micro_batch_size = 64  # Per GPU
config.rllm.workflow.n_parallel_tasks = 512
```
## Saving and Loading LoRA Models

### Saving

LoRA adapters are saved automatically during training:

```python
trainer.train()  # Checkpoints saved to output_dir
```

Checkpoint structure:

```
checkpoints/
├── actor/
│   ├── adapter_config.json
│   ├── adapter_model.bin   # LoRA weights (~10-50 MB)
│   └── ...
└── critic/
    └── ...
```
### Loading for Inference

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the base model
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct"
)

# Load the LoRA adapter
model = PeftModel.from_pretrained(
    base_model,
    "checkpoints/actor/",
)

# Merge for faster inference (optional)
model = model.merge_and_unload()
```
### Loading for Continued Training

```python
config.actor_rollout_ref.actor.model.path = "meta-llama/Llama-3.1-8B-Instruct"
config.actor_rollout_ref.actor.model.enable_lora = True
config.actor_rollout_ref.actor.model.lora_checkpoint = "checkpoints/actor/"

trainer = AgentTrainer(
    agent_class=MathAgent,
    env_class=SingleTurnEnvironment,
    config=config,
    train_dataset=train_dataset,
)
trainer.train()  # Continues from the LoRA checkpoint
```
## Data Preparation

LoRA works with the same data format as full fine-tuning. The original snippet calls an `extract_answer` helper without defining it; the version below (a reconstruction, not the example's code) relies on GSM8K's convention that answers end with "#### <final answer>":

```python
from datasets import load_dataset

from rllm.data.dataset import DatasetRegistry

# Load the raw dataset
gsm8k = load_dataset("openai/gsm8k", "main")


def extract_answer(answer: str) -> str:
    # GSM8K solutions end with "#### <final answer>"
    return answer.split("####")[-1].strip()


# Preprocess
def preprocess_fn(example, idx):
    return {
        "question": example["question"],
        "ground_truth": extract_answer(example["answer"]),
        "data_source": "gsm8k",
    }


train_dataset = gsm8k["train"].map(preprocess_fn, with_indices=True)

# Register
train_dataset = DatasetRegistry.register_dataset(
    "gsm8k",
    train_dataset,
    "train",
)
```

From examples/gsm8k_lora/prepare_gsm8k_data.py:17.
## Best Practices

- **Start with rank=16**: a good balance for most tasks
- **Use alpha = 2 * rank**: the standard scaling factor
- **Increase the LR**: LoRA works well with higher learning rates (5e-5)
- **Monitor loss**: if training is unstable, reduce rank or alpha
- **Save adapters**: much smaller than full checkpoints
- **Test on validation**: LoRA can overfit faster than full fine-tuning
## Common Issues

### Training Loss Not Decreasing

- Increase `lora_rank` (e.g., from 8 to 16)
- Increase `lora_alpha` (e.g., from 16 to 32)
- Increase the learning rate (e.g., from 1e-5 to 5e-5)
- Add more target modules (include the MLP layers)

### Out of Memory

- Reduce `lora_rank`
- Enable gradient checkpointing
- Reduce the batch size
- Use QLoRA (4-bit quantization)

### LoRA Not Loading

- Check that the adapter files exist in the checkpoint directory
- Verify that `adapter_config.json` matches your configuration
- Ensure the base model path is correct
- Try loading with PEFT directly to debug
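Before reaching for PEFT, you can sanity-check the checkpoint directory with the standard library alone. This sketch assumes the checkpoint layout shown earlier (`adapter_config.json` plus an `adapter_model.bin` or `adapter_model.safetensors` weights file) and PEFT's usual config keys (`r`, `lora_alpha`, `base_model_name_or_path`):

```python
import json
import os


def check_adapter_dir(path: str) -> list[str]:
    """Return a list of problems found in a LoRA checkpoint directory."""
    cfg_path = os.path.join(path, "adapter_config.json")
    if not os.path.isfile(cfg_path):
        return ["missing adapter_config.json"]

    with open(cfg_path) as f:
        cfg = json.load(f)

    problems = []
    # PEFT stores the rank as "r" and the base model as "base_model_name_or_path"
    for key in ("r", "lora_alpha", "base_model_name_or_path"):
        if key not in cfg:
            problems.append(f"adapter_config.json is missing '{key}'")

    # Weights may be .bin or .safetensors depending on the PEFT version
    if not any(
        os.path.isfile(os.path.join(path, name))
        for name in ("adapter_model.bin", "adapter_model.safetensors")
    ):
        problems.append("no adapter weights file found")
    return problems


print(check_adapter_dir("checkpoints/actor/"))
```

An empty list means the directory at least has the files and keys PEFT expects; any remaining failure is then more likely a base-model mismatch than a broken checkpoint.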
## Performance Comparison

| Method | Training Time | Memory | Checkpoint Size | Performance |
|---|---|---|---|---|
| Full Fine-Tuning | 1.0x | 28 GB | 14 GB | 100% |
| LoRA (rank=8) | 0.6x | 12 GB | 20 MB | 95% |
| LoRA (rank=16) | 0.7x | 14 GB | 40 MB | 97% |
| LoRA (rank=32) | 0.8x | 16 GB | 80 MB | 98% |

Performance relative to full fine-tuning varies by task. For most tasks, LoRA with rank=16 achieves 95-99% of full fine-tuning performance.
## Next Steps