PufferLib doesn’t enforce a strict policy interface. You can use any PyTorch module that returns actions and values. This guide shows you how to create custom policies compatible with PufferLib.
Policy interface
Policies must implement two methods:
forward(observations, state) - Training forward pass
forward_eval(observations, state) - Inference forward pass
Both methods should return (logits, values) where:
logits - Action logits (or torch.distributions.Normal for continuous actions)
values - State value predictions
Basic policy example
Here’s a simple feedforward policy:
import torch
import torch.nn as nn
import pufferlib.pytorch
class Policy(nn.Module):
    """Minimal feedforward policy: a two-layer MLP trunk feeding separate
    action and value heads.

    Returns (logits, values) from both forward() and forward_eval(), as
    required by the PufferLib policy interface.
    """

    def __init__(self, env, hidden_size=128):
        super().__init__()
        obs_dim = env.single_observation_space.shape[0]
        # Shared trunk: obs -> hidden features
        self.net = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(obs_dim, hidden_size)),
            nn.ReLU(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
        )
        # Small std on the action head keeps initial policy near-uniform
        self.action_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, env.single_action_space.n), std=0.01)
        self.value_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1)

    def forward_eval(self, observations, state=None):
        """Inference pass. Returns (action logits, value predictions)."""
        features = self.net(observations)
        return self.action_head(features), self.value_head(features)

    def forward(self, observations, state=None):
        """Training pass; identical to inference for this stateless policy."""
        return self.forward_eval(observations, state)
Default policy
PufferLib provides a default policy that handles multiple action space types:
from pufferlib.models import Default
class Default(nn.Module):
    """Default policy handling Discrete, MultiDiscrete, and Box action spaces.

    Structured as encode_observations() + decode_actions() so it is
    compatible with the LSTM wrapper. Returns (logits, values) where logits
    is a tensor (Discrete), a tuple of per-dimension tensors (MultiDiscrete),
    or a torch.distributions.Normal (Box).
    """

    def __init__(self, env, hidden_size=128):
        super().__init__()
        self.hidden_size = hidden_size
        self.is_multidiscrete = isinstance(env.single_action_space,
            pufferlib.spaces.MultiDiscrete)
        self.is_continuous = isinstance(env.single_action_space,
            pufferlib.spaces.Box)

        # Encoder: flattened observations -> hidden
        num_obs = int(np.prod(env.single_observation_space.shape))
        self.encoder = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(num_obs, hidden_size)),
            nn.GELU(),
        )

        # Decoder for continuous actions: Gaussian with state-independent log-std
        if self.is_continuous:
            self.decoder_mean = pufferlib.pytorch.layer_init(
                nn.Linear(hidden_size, env.single_action_space.shape[0]),
                std=0.01)
            self.decoder_logstd = nn.Parameter(
                torch.zeros(1, env.single_action_space.shape[0]))
        # Decoder for multi-discrete actions. NOTE: MultiDiscrete spaces
        # expose .nvec (per-dimension action counts), not .n — reading .n
        # here would raise AttributeError. One head outputs the concatenated
        # logits; decode_actions() splits them per dimension.
        elif self.is_multidiscrete:
            self.action_nvec = tuple(int(n) for n in env.single_action_space.nvec)
            self.decoder = pufferlib.pytorch.layer_init(
                nn.Linear(hidden_size, sum(self.action_nvec)), std=0.01)
        # Decoder for a single discrete action
        else:
            num_atns = env.single_action_space.n
            self.decoder = pufferlib.pytorch.layer_init(
                nn.Linear(hidden_size, num_atns), std=0.01)

        self.value = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1)

    def forward_eval(self, observations, state=None):
        """Inference pass: encode then decode. Returns (logits, values)."""
        hidden = self.encode_observations(observations, state=state)
        logits, values = self.decode_actions(hidden)
        return logits, values

    def forward(self, observations, state=None):
        """Training pass; identical to inference for this stateless policy."""
        return self.forward_eval(observations, state)

    def encode_observations(self, observations, state=None):
        """Flatten observations to (batch, -1) and encode to hidden features."""
        batch_size = observations.shape[0]
        observations = observations.view(batch_size, -1)
        return self.encoder(observations.float())

    def decode_actions(self, hidden):
        """Map hidden features to action logits (space-dependent) and values."""
        if self.is_continuous:
            mean = self.decoder_mean(hidden)
            logstd = self.decoder_logstd.expand_as(mean)
            std = torch.exp(logstd)
            logits = torch.distributions.Normal(mean, std)
        elif self.is_multidiscrete:
            # One categorical distribution per action dimension
            logits = self.decoder(hidden).split(self.action_nvec, dim=1)
        else:
            logits = self.decoder(hidden)
        values = self.value(hidden)
        return logits, values
Structured policies with encode/decode
For LSTM compatibility, structure your policy with separate encoding and decoding:
class StructuredPolicy(nn.Module):
    """Policy split into encode_observations()/decode_actions() so it can be
    wrapped by the LSTM wrapper.

    encode_observations() maps observations to hidden features; the wrapper
    inserts its recurrence between encode and decode.
    """

    def __init__(self, env, hidden_size=256):
        super().__init__()
        self.hidden_size = hidden_size
        obs_dim = env.single_observation_space.shape[0]
        # Encoder: observations -> hidden features
        self.encoder = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(obs_dim, hidden_size)),
            nn.GELU(),
        )
        # Decoder heads: hidden -> action logits and value estimate
        self.action_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, env.single_action_space.n), std=0.01)
        self.value_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1)

    def encode_observations(self, observations, state=None):
        """Flatten observations to (batch, -1) and encode to hidden features."""
        flat = observations.view(observations.shape[0], -1)
        return self.encoder(flat.float())

    def decode_actions(self, hidden):
        """Map hidden features to (action logits, value predictions)."""
        return self.action_head(hidden), self.value_head(hidden)

    def forward_eval(self, observations, state=None):
        """Inference pass: encode then decode."""
        return self.decode_actions(self.encode_observations(observations, state))

    def forward(self, observations, state=None):
        """Training pass; identical to inference for this stateless policy."""
        return self.forward_eval(observations, state)
Convolutional policy
For image observations (e.g., Atari):
class Convolutional(nn.Module):
    """Nature-CNN style policy for image observations (e.g. Atari).

    Three conv layers followed by a linear projection; separate actor and
    value heads. Pixel inputs are scaled from [0, 255] to [0, 1].
    """

    def __init__(self, env, framestack=4, flat_size=3136,
            hidden_size=512, channels_last=False):
        super().__init__()
        self.channels_last = channels_last
        self.hidden_size = hidden_size
        self.is_continuous = False
        init = pufferlib.pytorch.layer_init
        # Conv trunk: (framestack, H, W) -> flat -> hidden
        self.network = nn.Sequential(
            init(nn.Conv2d(framestack, 32, 8, stride=4)),
            nn.ReLU(),
            init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            init(nn.Linear(flat_size, hidden_size)),
            nn.ReLU(),
        )
        self.actor = init(
            nn.Linear(hidden_size, env.single_action_space.n), std=0.01)
        self.value_fn = init(nn.Linear(hidden_size, 1), std=1)

    def encode_observations(self, observations, state=None):
        """Normalize pixels and run the conv trunk. NHWC inputs are permuted
        to NCHW when channels_last is set."""
        if self.channels_last:
            observations = observations.permute(0, 3, 1, 2)
        return self.network(observations.float() / 255.0)

    def decode_actions(self, hidden):
        """Map hidden features to (action logits, value predictions)."""
        return self.actor(hidden), self.value_fn(hidden)

    def forward(self, observations, state=None):
        """Training pass: encode then decode."""
        return self.decode_actions(self.encode_observations(observations))

    def forward_eval(self, observations, state=None):
        """Inference pass; identical to the training pass."""
        return self.forward(observations, state)
LSTM wrapper
Wrap any structured policy with an LSTM:
from pufferlib.models import LSTMWrapper
# Create base policy
base_policy = StructuredPolicy(env, hidden_size=512)
# Wrap with LSTM
policy = LSTMWrapper(env, base_policy,
input_size=512,
hidden_size=512)
The LSTM wrapper automatically manages hidden states during training and inference.
LSTM wrapper implementation
Here’s how the LSTM wrapper works:
class LSTMWrapper(nn.Module):
    """Adds an LSTM between a structured policy's encode_observations() and
    decode_actions().

    Training uses the batched nn.LSTM over (B, TT) sequences; inference uses
    an nn.LSTMCell that shares the LSTM's weights for single-step updates.
    Recurrent state is carried in the `state` dict under 'lstm_h'/'lstm_c'.
    """

    def __init__(self, env, policy, input_size=128, hidden_size=128):
        super().__init__()
        self.obs_shape = env.single_observation_space.shape
        self.policy = policy
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.is_continuous = policy.is_continuous
        # Initialize LSTM
        self.lstm = nn.LSTM(input_size, hidden_size)
        # Create LSTM cell for fast inference; the cell's parameters are
        # tied to the LSTM's layer-0 parameters so both paths share weights.
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.cell.weight_ih = self.lstm.weight_ih_l0
        self.cell.weight_hh = self.lstm.weight_hh_l0
        self.cell.bias_ih = self.lstm.bias_ih_l0
        self.cell.bias_hh = self.lstm.bias_hh_l0

    def forward_eval(self, observations, state):
        """Fast inference with LSTMCell."""
        hidden = self.policy.encode_observations(observations, state=state)
        h = state['lstm_h']
        c = state['lstm_c']
        # None state means start of an episode; LSTMCell then zero-initializes.
        if h is not None:
            lstm_state = (h, c)
        else:
            lstm_state = None
        # LSTMCell returns (h, c); the new h doubles as the decoder input.
        hidden, c = self.cell(hidden, lstm_state)
        state['lstm_h'] = hidden
        state['lstm_c'] = c
        logits, values = self.policy.decode_actions(hidden)
        return logits, values

    def forward(self, observations, state):
        """Training forward with batched LSTM."""
        x = observations
        lstm_h = state['lstm_h']
        lstm_c = state['lstm_c']
        x_shape, space_shape = x.shape, self.obs_shape
        x_n, space_n = len(x_shape), len(space_shape)
        # Handle batch and time dimensions: accept (B, *obs) as a single
        # step or (B, TT, *obs) as a BPTT sequence.
        if x_n == space_n + 1:
            B, TT = x_shape[0], 1
        elif x_n == space_n + 2:
            B, TT = x_shape[:2]
        else:
            raise ValueError('Invalid input tensor shape', x.shape)
        if lstm_h is not None:
            lstm_state = (lstm_h, lstm_c)
        else:
            lstm_state = None
        # Encode observations with batch and time flattened together
        x = x.reshape(B*TT, *space_shape)
        hidden = self.policy.encode_observations(x, state)
        hidden = hidden.reshape(B, TT, self.input_size)
        # LSTM forward; nn.LSTM expects time-major (TT, B, input_size)
        hidden = hidden.transpose(0, 1)
        hidden, (lstm_h, lstm_c) = self.lstm.forward(hidden, lstm_state)
        hidden = hidden.transpose(0, 1)
        # Decode actions on flattened (B*TT, hidden) features
        flat_hidden = hidden.reshape(B*TT, self.hidden_size)
        logits, values = self.policy.decode_actions(flat_hidden)
        values = values.reshape(B, TT)
        # Detach carried state so gradients stop at the BPTT boundary
        state['lstm_h'] = lstm_h.detach()
        state['lstm_c'] = lstm_c.detach()
        return logits, values
Using LSTM wrapper
Structure your policy
Split your policy into encode_observations() and decode_actions() methods.
Wrap with LSTM
from pufferlib.models import LSTMWrapper
base_policy = MyPolicy(env, hidden_size=512)
policy = LSTMWrapper(env, base_policy,
input_size=512, hidden_size=512)
Configure training
args['rnn_name'] = 'LSTMWrapper'
args['train']['use_rnn'] = True
args['train']['bptt_horizon'] = 64 # Sequence length
Continuous action spaces
For continuous actions, return a torch.distributions.Normal:
class ContinuousPolicy(nn.Module):
    """Gaussian policy for continuous (Box) action spaces.

    Returns (torch.distributions.Normal, values): the mean comes from a
    linear head and the log-std is a learned, state-independent parameter.
    """

    def __init__(self, env, hidden_size=256):
        super().__init__()
        self.is_continuous = True
        obs_dim = env.single_observation_space.shape[0]
        action_dim = env.single_action_space.shape[0]
        # Tanh MLP trunk: observations -> hidden features
        self.encoder = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(obs_dim, hidden_size)),
            nn.Tanh(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.Tanh(),
        )
        self.mean_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, action_dim), std=0.01)
        # Log-std is shared across states and broadcast over the batch
        self.logstd = nn.Parameter(torch.zeros(1, action_dim))
        self.value_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1)

    def forward_eval(self, observations, state=None):
        """Inference pass. Returns (Normal distribution, value predictions)."""
        features = self.encoder(observations.float())
        mean = self.mean_head(features)
        std = torch.exp(self.logstd.expand_as(mean))
        return torch.distributions.Normal(mean, std), self.value_head(features)

    def forward(self, observations, state=None):
        """Training pass; identical to inference for this stateless policy."""
        return self.forward_eval(observations, state)
Multi-discrete action spaces
For multi-discrete actions:
class MultiDiscretePolicy(nn.Module):
    """Policy for MultiDiscrete action spaces.

    A single linear head emits the concatenated logits for every action
    dimension; forward passes split them into one tensor per dimension,
    so logits is a tuple with one entry per element of the space's nvec.
    """

    def __init__(self, env, hidden_size=256):
        super().__init__()
        self.is_multidiscrete = True
        self.action_nvec = tuple(env.single_action_space.nvec)
        obs_dim = env.single_observation_space.shape[0]
        # Encoder: observations -> hidden features
        self.encoder = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(obs_dim, hidden_size)),
            nn.GELU(),
        )
        # One head sized to the total logit count across all dimensions
        self.action_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, sum(self.action_nvec)), std=0.01)
        self.value_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1)

    def forward_eval(self, observations, state=None):
        """Inference pass. Returns (per-dimension logits tuple, values)."""
        features = self.encoder(observations.float())
        # Split the concatenated logits into one chunk per action dimension
        per_dim_logits = self.action_head(features).split(self.action_nvec, dim=1)
        return per_dim_logits, self.value_head(features)

    def forward(self, observations, state=None):
        """Training pass; identical to inference for this stateless policy."""
        return self.forward_eval(observations, state)
Layer initialization
Use pufferlib.pytorch.layer_init() for proper initialization:
import pufferlib.pytorch
# Default orthogonal initialization (std=sqrt(2))
layer = pufferlib.pytorch.layer_init(nn.Linear(128, 64))
# Custom standard deviation
action_head = pufferlib.pytorch.layer_init(nn.Linear(64, 4), std=0.01)
value_head = pufferlib.pytorch.layer_init(nn.Linear(64, 1), std=1.0)
This applies orthogonal initialization to weights and zeros to biases.
Using custom policies
In Python
from pufferlib import pufferl
args = pufferl.load_config('puffer_breakout')
vecenv = pufferl.load_env('puffer_breakout', args)
# Use custom policy
policy = MyCustomPolicy(vecenv.driver_env, hidden_size=512).cuda()
trainer = pufferl.PuffeRL(args['train'], vecenv, policy)
Via configuration
Create a policy module:
import torch.nn as nn
import pufferlib.pytorch
class MyPolicy(nn.Module):
    """Placeholder custom policy; referenced by name from the config below."""

    def __init__(self, env, hidden_size=256):
        # ... implementation ...
        pass
Reference it in your config:
[base]
package = my_package
env_name = my_env
policy_name = MyPolicy
[policy]
hidden_size = 512
Best practices
Use encode/decode structure
Split policies into encode_observations() and decode_actions() for LSTM compatibility.
Initialize layers properly
Use pufferlib.pytorch.layer_init() with appropriate std values (0.01 for action heads, 1.0 for value heads).
Match forward signatures
Both forward() and forward_eval() should accept (observations, state=None) and return (logits, values).
Handle observation shapes
Flatten observations appropriately for your architecture.
Set is_continuous flag
Set self.is_continuous = True for continuous action spaces.
For continuous actions, you must return a torch.distributions.Normal distribution, not raw tensors.