PufferRL is PufferLib’s built-in training system that implements Proximal Policy Optimization (PPO) with advanced features like V-trace, prioritized experience replay, and automatic mixed precision.
Overview
The PuffeRL class handles the complete training loop including:
- Environment rollouts with vectorized execution
- PPO with clipped surrogate objective
- V-trace for off-policy correction
- Prioritized experience replay
- Gradient accumulation and automatic mixed precision
- LSTM support for recurrent policies
- Distributed training with PyTorch DDP
Basic usage
Load configuration
```python
from pufferlib import pufferl

args = pufferl.load_config('puffer_breakout')
```
Create environment and policy
```python
vecenv = pufferl.load_env('puffer_breakout', args)
policy = pufferl.load_policy(args, vecenv, 'puffer_breakout')
```
Initialize trainer
```python
train_config = {**args['train'], 'env': 'puffer_breakout'}
trainer = pufferl.PuffeRL(train_config, vecenv, policy)
```
Run training loop
```python
while trainer.global_step < train_config['total_timesteps']:
    trainer.evaluate()      # Collect experience
    logs = trainer.train()  # Update policy

trainer.close()
```
Complete training example
```python
import torch
import pufferlib.vector
import pufferlib.ocean
from pufferlib import pufferl

env_name = 'puffer_breakout'
args = pufferl.load_config(env_name)

# Customize configuration
args['vec']['num_envs'] = 2
args['env']['num_envs'] = 2048
args['policy']['hidden_size'] = 256
args['rnn']['input_size'] = 256
args['rnn']['hidden_size'] = 256
args['train']['total_timesteps'] = 10_000_000
args['train']['learning_rate'] = 0.03

vecenv = pufferl.load_env(env_name, args)
policy = pufferl.load_policy(args, vecenv, env_name)
trainer = pufferl.PuffeRL(args['train'], vecenv, policy)

while trainer.epoch < trainer.total_epochs:
    trainer.evaluate()
    logs = trainer.train()
    trainer.print_dashboard()

trainer.close()
```
PuffeRL class reference
Constructor
```python
PuffeRL(config, vecenv, policy, logger=None)
```
- `config` - Training configuration dictionary. See configuration for all available options.
- `vecenv` - Vectorized environment created with `pufferlib.vector.make()`.
- `policy` - PyTorch policy network. Must implement `forward()` and `forward_eval()` methods.
- `logger` - Optional logger for tracking metrics. Supports `WandbLogger` and `NeptuneLogger`.
Methods
evaluate()
Collects experience from the environment by running the policy.
```python
stats = trainer.evaluate()
```
Fills internal buffers with observations, actions, rewards, values, and log probabilities. Returns environment statistics.
train()
Performs PPO updates on collected experience.
Returns a dictionary of training metrics including:
- `policy_loss` - Clipped PPO policy loss
- `value_loss` - Clipped value function loss
- `entropy` - Policy entropy
- `approx_kl` - Approximate KL divergence
- `clipfrac` - Fraction of clipped policy updates
- `explained_variance` - Quality of value predictions
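The explained variance reported here is the conventional quantity 1 - Var[returns - values] / Var[returns]: 1.0 means the value function predicts returns perfectly, 0.0 means it is no better than predicting the mean. A quick standalone check on toy numbers (not trainer output):

```python
import torch

# Toy returns and value predictions that track them closely.
returns = torch.tensor([1.0, 2.0, 3.0, 4.0])
values = torch.tensor([1.1, 1.9, 3.2, 3.8])

explained_variance = 1.0 - (returns - values).var() / returns.var()
print(explained_variance.item())  # 0.98: small residual variance
```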
close()
Saves final checkpoint and closes resources.
```python
model_path = trainer.close()
```
Returns path to the saved model file.
save_checkpoint()
Saves model and optimizer state.
```python
model_path = trainer.save_checkpoint()
```
print_dashboard()
Displays training progress in the terminal.
```python
trainer.print_dashboard()
```
Properties
global_step
Total environment steps collected across all processes.
```python
steps = trainer.global_step
```
epoch
Current training epoch (number of `train()` calls).
sps
Steps per second (throughput).
uptime
Total training time in seconds.
PPO implementation details
PufferRL implements PPO with several enhancements:
Clipped surrogate objective
The policy loss uses the standard PPO clipping:
```python
logratio = new_logprob - old_logprob
ratio = logratio.exp()  # computed in log space for numerical stability
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
```
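A self-contained toy run of the same objective (hand-picked ratios, not trainer data) shows how the clamp caps each side of the update:

```python
import torch

# Three transitions with probability ratios 1.0, 1.8, and 0.2 under a
# clip coefficient of 0.2; the last two fall outside the trust region.
old_logprob = torch.tensor([0.5, 0.5, 0.5]).log()
new_logprob = torch.tensor([0.5, 0.9, 0.1]).log()
advantages = torch.tensor([1.0, 1.0, -1.0])
clip_coef = 0.2

ratio = (new_logprob - old_logprob).exp()
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
clipfrac = ((ratio - 1.0).abs() > clip_coef).float().mean()
```

Here `torch.max` keeps the pessimistic (larger-loss) branch per transition, so updates that push the ratio past 1 ± 0.2 gain nothing beyond the clipped value; `clipfrac` reports 2/3 of the transitions as clipped.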
V-trace correction
Off-policy correction for better sample efficiency:
```python
advantages = compute_puff_advantage(
    values, rewards, terminals, ratio, advantages,
    gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip
)
```
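`compute_puff_advantage` is internal to PufferLib; the sketch below is an illustrative reimplementation of the idea only (function name, signature, and indexing conventions are assumptions, not the real kernel). It is a GAE backward pass in which the importance ratio is clipped by `rho_clip` inside the TD error and by `c_clip` in the recursive trace term:

```python
import torch

# Illustrative V-trace-style advantage, NOT PufferLib's exact code.
# values has T+1 entries; the final entry is the bootstrap value.
def vtrace_advantage(values, rewards, terminals, ratio,
                     gamma=0.99, gae_lambda=0.95,
                     rho_clip=1.0, c_clip=1.0):
    T = rewards.shape[0]
    adv = torch.zeros(T)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        rho = ratio[t].clamp(max=rho_clip)  # clip the TD-error weight
        c = ratio[t].clamp(max=c_clip)      # clip the trace weight
        nonterminal = 1.0 - terminals[t]
        delta = rho * (rewards[t] + gamma * nonterminal * values[t + 1]
                       - values[t])
        lastgaelam = delta + gamma * gae_lambda * nonterminal * c * lastgaelam
        adv[t] = lastgaelam
    return adv
```

With `ratio` all ones and both clips at 1.0 this reduces to standard GAE, which is a useful sanity check.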
Prioritized experience replay
Samples experience based on advantage magnitude:
```python
adv = advantages.abs().sum(axis=1)
prio_weights = adv**prio_alpha
prio_probs = prio_weights / prio_weights.sum()
idx = torch.multinomial(prio_probs, minibatch_segments)
```
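A standalone run of this sampling step (the shapes are assumptions for illustration: 8 segments of horizon 16, drawing 4):

```python
import torch

torch.manual_seed(0)
advantages = torch.randn(8, 16)  # (segments, horizon), toy data
prio_alpha = 0.6
minibatch_segments = 4

adv = advantages.abs().sum(axis=1)              # advantage mass per segment
prio_weights = adv ** prio_alpha                # temper with alpha
prio_probs = prio_weights / prio_weights.sum()  # normalize to a distribution
idx = torch.multinomial(prio_probs, minibatch_segments)
```

`prio_alpha = 0` recovers uniform sampling; `prio_alpha = 1` samples proportionally to advantage magnitude. `torch.multinomial` draws without replacement by default, so the minibatch contains distinct segments.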
Value function clipping
Clipped value loss for stable training:
```python
v_clipped = old_values + torch.clamp(new_values - old_values, -vf_clip, vf_clip)
v_loss_unclipped = (new_values - returns) ** 2
v_loss_clipped = (v_clipped - returns) ** 2
v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
```
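Worked through on toy numbers (assumed for illustration, not trainer data) with `vf_clip = 0.1`:

```python
import torch

old_values = torch.tensor([0.0, 0.0])
new_values = torch.tensor([1.0, 0.05])  # first prediction moves too far
returns = torch.tensor([0.5, 0.5])
vf_clip = 0.1

v_clipped = old_values + torch.clamp(new_values - old_values, -vf_clip, vf_clip)
v_loss_unclipped = (new_values - returns) ** 2
v_loss_clipped = (v_clipped - returns) ** 2
v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
```

The first prediction is clipped back to 0.1, and taking the elementwise max keeps the larger error for it (0.25 rather than 0.16), so large value jumps are penalized without letting the clipped estimate hide the miss; the second prediction is within the clip range and both branches agree.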
Distributed training
PufferRL supports distributed training with PyTorch DDP:
```shell
torchrun --standalone --nnodes=1 --nproc-per-node=6 \
    -m pufferlib.pufferl train puffer_breakout
```
The training loop automatically:
- Initializes process groups
- Wraps the policy with DistributedDataParallel
- Synchronizes gradients across processes
- Aggregates metrics from all workers
Only rank 0 saves checkpoints and logs metrics to avoid conflicts.
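The wrapping step itself is the standard DDP pattern; a single-process sketch with the gloo backend so it runs without `torchrun` or GPUs (the address and port are placeholders):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# One-process 'world' for illustration; torchrun normally sets these.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29511')
dist.init_process_group('gloo', rank=0, world_size=1)

policy = DDP(torch.nn.Linear(8, 4))  # gradients sync across ranks
loss = policy(torch.randn(2, 8)).sum()
loss.backward()                      # all-reduce happens here

if dist.get_rank() == 0:             # only rank 0 checkpoints/logs
    pass
dist.destroy_process_group()
```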
LSTM support
For recurrent policies, PufferRL maintains hidden states per environment:
```python
if config['use_rnn']:
    n = vecenv.agents_per_batch
    h = policy.hidden_size
    self.lstm_h = {i*n: torch.zeros(n, h, device=device)
                   for i in range(total_agents//n)}
    self.lstm_c = {i*n: torch.zeros(n, h, device=device)
                   for i in range(total_agents//n)}
```
Hidden states are automatically managed during rollouts and reset on episode boundaries.
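The reset-on-done step can be sketched as masking out the state rows of finished environments (shapes and names here are illustrative, not PufferRL's internals):

```python
import torch

n, h = 4, 8                    # envs in batch, hidden size
lstm_h = torch.ones(n, h)      # stand-in for carried-over state
lstm_c = torch.ones(n, h)
done = torch.tensor([0.0, 1.0, 0.0, 1.0])  # envs 1 and 3 just finished

mask = (1.0 - done).unsqueeze(-1)  # (n, 1), broadcasts over features
lstm_h = lstm_h * mask             # zero the state of finished envs
lstm_c = lstm_c * mask
```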
Automatic mixed precision
PufferRL uses automatic mixed precision (AMP) for faster training:
```python
if config['device'] == 'cuda':
    self.amp_context = torch.amp.autocast(
        device_type='cuda',
        dtype=getattr(torch, config['precision'])
    )
```
Supported precisions:
- `float32` - Full precision (default)
- `bfloat16` - Brain floating point (faster on modern GPUs)
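The same autocast pattern runs on CPU, which makes for a portable sketch of what the context does (PufferRL itself targets `device_type='cuda'`):

```python
import torch

policy = torch.nn.Linear(16, 4)
obs = torch.randn(8, 16)

# Inside autocast, matmul-heavy ops like Linear run in bfloat16 while
# the parameters themselves stay in float32.
with torch.amp.autocast(device_type='cpu', dtype=torch.bfloat16):
    logits = policy(obs)
```

With `bfloat16` no gradient scaler is needed, since it keeps float32's exponent range.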
Torch compile
Enable torch.compile for additional speedups:
```python
args['train']['compile'] = True
args['train']['compile_mode'] = 'max-autotune-no-cudagraphs'
```
This compiles the policy and sampling functions with TorchDynamo.
Checkpointing
Checkpoints are saved automatically based on checkpoint_interval:
```python
if self.epoch % config['checkpoint_interval'] == 0:
    self.save_checkpoint()
```
Checkpoint files include:
- Model state dict: `model_{env}_{epoch:06d}.pt`
- Training state: `trainer_state.pt` (optimizer, global_step, epoch)
To resume training, load a checkpoint:
```python
args['load_model_path'] = 'experiments/puffer_breakout_12345/model_puffer_breakout_000200.pt'
policy = pufferl.load_policy(args, vecenv, env_name)
```
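Under the hood this is the standard `torch.save`/`torch.load` state-dict pattern; a generic sketch of it (the file name and dictionary keys are illustrative, not PufferLib's exact checkpoint format):

```python
import os
import tempfile
import torch

policy = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(policy.parameters())

# Save model weights and training state together.
path = os.path.join(tempfile.mkdtemp(), 'model_demo_000001.pt')
torch.save({'model': policy.state_dict(),
            'optimizer': optimizer.state_dict(),
            'global_step': 1024,
            'epoch': 1}, path)

# Resume: restore both the weights and the optimizer state.
ckpt = torch.load(path, weights_only=True)
policy.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
```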