PuffeRL is PufferLib’s built-in training system. It implements Proximal Policy Optimization (PPO) with advanced features such as V-trace off-policy correction, prioritized experience replay, and automatic mixed precision.

Overview

The PuffeRL class handles the complete training loop including:
  • Environment rollouts with vectorized execution
  • PPO with clipped surrogate objective
  • V-trace for off-policy correction
  • Prioritized experience replay
  • Gradient accumulation and automatic mixed precision
  • LSTM support for recurrent policies
  • Distributed training with PyTorch DDP

Basic usage

1. Load configuration

from pufferlib import pufferl

args = pufferl.load_config('puffer_breakout')

2. Create environment and policy

vecenv = pufferl.load_env('puffer_breakout', args)
policy = pufferl.load_policy(args, vecenv, 'puffer_breakout')

3. Initialize trainer

train_config = {**args['train'], 'env': 'puffer_breakout'}
trainer = pufferl.PuffeRL(train_config, vecenv, policy)

4. Run training loop

while trainer.global_step < train_config['total_timesteps']:
    trainer.evaluate()  # Collect experience
    logs = trainer.train()  # Update policy

trainer.close()

Complete training example

import torch
import pufferlib.vector
import pufferlib.ocean
from pufferlib import pufferl

env_name = 'puffer_breakout'
args = pufferl.load_config(env_name)

# Customize configuration
args['vec']['num_envs'] = 2
args['env']['num_envs'] = 2048
args['policy']['hidden_size'] = 256
args['rnn']['input_size'] = 256
args['rnn']['hidden_size'] = 256
args['train']['total_timesteps'] = 10_000_000
args['train']['learning_rate'] = 0.03

vecenv = pufferl.load_env(env_name, args)
policy = pufferl.load_policy(args, vecenv, env_name)

trainer = pufferl.PuffeRL(args['train'], vecenv, policy)

while trainer.epoch < trainer.total_epochs:
    trainer.evaluate()
    logs = trainer.train()

trainer.print_dashboard()
trainer.close()

PuffeRL class reference

Constructor

PuffeRL(config, vecenv, policy, logger=None)

  • config (dict, required) - Training configuration dictionary. See configuration for all available options.
  • vecenv (VectorEnv, required) - Vectorized environment created with pufferlib.vector.make().
  • policy (nn.Module, required) - PyTorch policy network. Must implement forward() and forward_eval() methods.
  • logger (Logger, default None) - Optional logger for tracking metrics. Supports WandbLogger and NeptuneLogger.

Methods

evaluate()

Collects experience from the environment by running the policy.
stats = trainer.evaluate()
Fills internal buffers with observations, actions, rewards, values, and log probabilities. Returns environment statistics.

train()

Performs PPO updates on collected experience.
logs = trainer.train()
Returns a dictionary of training metrics including:
  • policy_loss - Clipped PPO policy loss
  • value_loss - Clipped value function loss
  • entropy - Policy entropy
  • approx_kl - Approximate KL divergence
  • clipfrac - Fraction of clipped policy updates
  • explained_variance - Quality of value predictions
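
For example, a few of these metrics can be pulled out of the returned dictionary for quick monitoring. The key names below are assumptions based on the metric list above; check the actual dictionary keys in your version:
logs = trainer.train()
print(f"policy_loss={logs['policy_loss']:.4f}  "
      f"value_loss={logs['value_loss']:.4f}  "
      f"approx_kl={logs['approx_kl']:.5f}")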

close()

Saves final checkpoint and closes resources.
model_path = trainer.close()
Returns path to the saved model file.

save_checkpoint()

Saves model and optimizer state.
model_path = trainer.save_checkpoint()

print_dashboard()

Displays training progress in the terminal.
trainer.print_dashboard()

Properties

global_step

Total environment steps collected across all processes.
steps = trainer.global_step

epoch

Current training epoch (number of train() calls).
epoch = trainer.epoch

sps

Steps per second (throughput).
sps = trainer.sps

uptime

Total training time in seconds.
uptime = trainer.uptime
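
These properties can be combined into a simple progress readout; this is only a usage sketch built from the properties documented above:
while trainer.global_step < args['train']['total_timesteps']:
    trainer.evaluate()
    logs = trainer.train()
    print(f'epoch {trainer.epoch} | {trainer.global_step} steps | '
          f'{trainer.sps:.0f} steps/s | {trainer.uptime:.0f}s elapsed')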

PPO implementation details

PuffeRL implements PPO with several enhancements:

Clipped surrogate objective

The policy loss uses the standard PPO clipping:
# Importance ratio between new and old policies, computed in log space for stability
ratio = (new_logprob - old_logprob).exp()
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
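The clipping behavior can be checked on toy tensors; this standalone snippet is purely illustrative and is not the trainer's internal code:
import torch

advantages = torch.tensor([1.0, -0.5, 2.0])
old_logprob = torch.tensor([-1.2, -0.7, -2.0])
new_logprob = torch.tensor([-1.0, -0.9, -1.5])
clip_coef = 0.2

ratio = (new_logprob - old_logprob).exp()
pg_loss1 = -advantages * ratio
pg_loss2 = -advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()
print(pg_loss)  # scalar policy loss; ratio moves beyond 1 ± clip_coef stop contributing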

V-trace correction

Off-policy correction for better sample efficiency:
advantages = compute_puff_advantage(
    values, rewards, terminals, ratio, advantages,
    gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip
)
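compute_puff_advantage is internal. As a rough, simplified sketch of the idea, a GAE-style backward recursion with V-trace-style clipped importance weights looks roughly like this; shapes, default clip values, and the function itself are assumptions, not the actual implementation:
import torch

def clipped_gae(values, rewards, dones, ratio, gamma=0.99, lam=0.95,
                rho_clip=1.0, c_clip=1.0):
    # values, rewards, dones, ratio: 1-D float tensors over one segment of length T
    T = rewards.shape[0]
    adv = torch.zeros(T)
    last = torch.tensor(0.0)
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else torch.tensor(0.0)
        not_done = 1.0 - dones[t]
        rho = torch.clamp(ratio[t], max=rho_clip)  # bounds the TD-error weight
        c = torch.clamp(ratio[t], max=c_clip)      # bounds the trace-decay weight
        delta = rho * (rewards[t] + gamma * next_value * not_done - values[t])
        last = delta + gamma * lam * c * not_done * last
        adv[t] = last
    return adv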

Prioritized experience replay

Samples experience based on advantage magnitude:
adv = advantages.abs().sum(axis=1)
prio_weights = adv**prio_alpha
prio_probs = prio_weights / prio_weights.sum()
idx = torch.multinomial(prio_probs, minibatch_segments)
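A self-contained toy version of this sampling step, with made-up shapes and hyperparameters:
import torch

advantages = torch.randn(8, 128)   # [segments, horizon], toy values
prio_alpha = 0.8                   # assumed prioritization exponent
minibatch_segments = 4

adv = advantages.abs().sum(dim=1)
prio_weights = adv ** prio_alpha
prio_probs = prio_weights / prio_weights.sum()
idx = torch.multinomial(prio_probs, minibatch_segments)
print(idx)  # indices of the segments selected for this minibatch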

Value function clipping

Clipped value loss for stable training:
v_clipped = old_values + torch.clamp(new_values - old_values, -vf_clip, vf_clip)
v_loss_unclipped = (new_values - returns) ** 2
v_loss_clipped = (v_clipped - returns) ** 2
v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
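As a toy illustration: when the new value prediction jumps far past vf_clip toward the return, the clipped term dominates the max, so the loss does not reward the large jump. Illustrative values only:
import torch

old_values = torch.tensor([1.0, 1.0])
new_values = torch.tensor([1.1, 2.0])   # second entry jumps far from its old value
returns = torch.tensor([2.0, 2.0])
vf_clip = 0.2

v_clipped = old_values + torch.clamp(new_values - old_values, -vf_clip, vf_clip)
v_loss_unclipped = (new_values - returns) ** 2
v_loss_clipped = (v_clipped - returns) ** 2
v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
print(v_loss_unclipped, v_loss_clipped)  # second entry: 0.00 unclipped vs 0.64 clipped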

Distributed training

PuffeRL supports distributed training with PyTorch DDP:
torchrun --standalone --nnodes=1 --nproc-per-node=6 \
  -m pufferlib.pufferl train puffer_breakout
The training loop automatically:
  • Initializes process groups
  • Wraps the policy with DistributedDataParallel
  • Synchronizes gradients across processes
  • Aggregates metrics from all workers
Only rank 0 saves checkpoints and logs metrics to avoid conflicts.
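If you drive the trainer yourself under torchrun, the same rank-0 guard can be used for any extra logging or saving you add. This is a generic PyTorch DDP pattern, not PuffeRL-specific code:
import torch.distributed as dist

# True when running single-process, or on rank 0 of a torchrun launch
is_master = not dist.is_initialized() or dist.get_rank() == 0
if is_master:
    print('rank 0: safe to write checkpoints and logs here')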

LSTM support

For recurrent policies, PuffeRL maintains hidden states per environment:
if config['use_rnn']:
    n = vecenv.agents_per_batch
    h = policy.hidden_size
    self.lstm_h = {i*n: torch.zeros(n, h, device=device) 
                   for i in range(total_agents//n)}
    self.lstm_c = {i*n: torch.zeros(n, h, device=device) 
                   for i in range(total_agents//n)}
Hidden states are automatically managed during rollouts and reset on episode boundaries.
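The reset on episode boundaries amounts to zeroing the state rows of finished environments. A toy illustration of that behavior, not the trainer's internal bookkeeping:
import torch

n_envs, hidden_size = 4, 8
h = torch.randn(n_envs, hidden_size)
c = torch.randn(n_envs, hidden_size)
dones = torch.tensor([False, True, False, True])

h[dones] = 0.0   # finished environments start the next episode with zeroed state
c[dones] = 0.0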

Automatic mixed precision

PuffeRL uses automatic mixed precision (AMP) for faster training:
if config['device'] == 'cuda':
    self.amp_context = torch.amp.autocast(
        device_type='cuda', 
        dtype=getattr(torch, config['precision'])
    )
Supported precisions:
  • float32 - Full precision (default)
  • bfloat16 - Brain floating point (faster on modern GPUs)
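To opt into bfloat16, set the precision option before constructing the trainer; this assumes the same 'precision' key shown in the snippet above lives under args['train']:
args['train']['precision'] = 'bfloat16'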

Torch compile

Enable torch.compile for additional speedups:
args['train']['compile'] = True
args['train']['compile_mode'] = 'max-autotune-no-cudagraphs'
This compiles the policy and sampling functions with TorchDynamo.

Checkpointing

Checkpoints are saved automatically based on checkpoint_interval:
if self.epoch % config['checkpoint_interval'] == 0:
    self.save_checkpoint()
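To change how often this fires, set the interval in the training config; the value here is arbitrary, and placing checkpoint_interval under args['train'] is an assumption consistent with the other training options:
args['train']['checkpoint_interval'] = 200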
Checkpoint files include:
  • Model state dict: model_{env}_{epoch:06d}.pt
  • Training state: trainer_state.pt (optimizer, global_step, epoch)
To resume training, load a checkpoint:
args['load_model_path'] = 'experiments/puffer_breakout_12345/model_puffer_breakout_000200.pt'
policy = pufferl.load_policy(args, vecenv, env_name)
