PufferLib doesn’t enforce a strict policy interface: you can use any PyTorch module that returns actions and values. This guide shows you how to create custom policies compatible with PufferLib.

Policy interface

Policies must implement two methods:
  • forward(observations, state) - Training forward pass
  • forward_eval(observations, state) - Inference forward pass
Both methods should return (logits, values) where:
  • logits - Action logits (or torch.distributions.Normal for continuous actions)
  • values - State value predictions

Basic policy example

Here’s a simple feedforward policy:
import torch
import torch.nn as nn
import pufferlib.pytorch

class Policy(nn.Module):
    def __init__(self, env, hidden_size=128):
        super().__init__()
        self.net = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(
                env.single_observation_space.shape[0], hidden_size)),
            nn.ReLU(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
        )
        self.action_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, env.single_action_space.n), std=0.01)
        self.value_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1)

    def forward_eval(self, observations, state=None):
        hidden = self.net(observations)
        logits = self.action_head(hidden)
        values = self.value_head(hidden)
        return logits, values

    def forward(self, observations, state=None):
        return self.forward_eval(observations, state)
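
To sanity-check a policy like this outside the trainer, you can push a random batch through it and sample actions from the logits. This is a minimal sketch: the stand-in env below is hypothetical and only mimics the two space attributes the policy reads.
import types
import torch

# Hypothetical stand-in for a vectorized env (only the attributes Policy reads)
env = types.SimpleNamespace(
    single_observation_space=types.SimpleNamespace(shape=(16,)),
    single_action_space=types.SimpleNamespace(n=4),
)

policy = Policy(env, hidden_size=128)
obs = torch.randn(8, 16)  # batch of 8 flat observations

logits, values = policy.forward_eval(obs)
actions = torch.distributions.Categorical(logits=logits).sample()
print(actions.shape, values.shape)  # torch.Size([8]) torch.Size([8, 1])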

Default policy

PufferLib provides a default policy that handles multiple action space types:
from pufferlib.models import Default
Its implementation looks like this (beyond the imports shown earlier, it also uses numpy as np and pufferlib.spaces):
class Default(nn.Module):
    def __init__(self, env, hidden_size=128):
        super().__init__()
        self.hidden_size = hidden_size
        self.is_multidiscrete = isinstance(env.single_action_space,
                pufferlib.spaces.MultiDiscrete)
        self.is_continuous = isinstance(env.single_action_space,
                pufferlib.spaces.Box)

        # Encoder
        num_obs = np.prod(env.single_observation_space.shape)
        self.encoder = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(num_obs, hidden_size)),
            nn.GELU(),
        )
        
        # Decoder for discrete actions
        if not self.is_continuous:
            num_atns = env.single_action_space.n
            self.decoder = pufferlib.pytorch.layer_init(
                nn.Linear(hidden_size, num_atns), std=0.01)
        # Decoder for continuous actions
        else:
            self.decoder_mean = pufferlib.pytorch.layer_init(
                nn.Linear(hidden_size, env.single_action_space.shape[0]), 
                std=0.01)
            self.decoder_logstd = nn.Parameter(
                torch.zeros(1, env.single_action_space.shape[0]))

        self.value = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1)

    def forward_eval(self, observations, state=None):
        hidden = self.encode_observations(observations, state=state)
        logits, values = self.decode_actions(hidden)
        return logits, values

    def forward(self, observations, state=None):
        return self.forward_eval(observations, state)

    def encode_observations(self, observations, state=None):
        batch_size = observations.shape[0]
        observations = observations.view(batch_size, -1)
        return self.encoder(observations.float())

    def decode_actions(self, hidden):
        if self.is_continuous:
            mean = self.decoder_mean(hidden)
            logstd = self.decoder_logstd.expand_as(mean)
            std = torch.exp(logstd)
            logits = torch.distributions.Normal(mean, std)
        else:
            logits = self.decoder(hidden)

        values = self.value(hidden)
        return logits, values
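
You normally don’t copy this code; import and instantiate the built-in class directly (a minimal sketch, assuming a vecenv built as in the usage section below):
from pufferlib.models import Default

policy = Default(vecenv.driver_env, hidden_size=128)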

Structured policies with encode/decode

For LSTM compatibility, structure your policy with separate encoding and decoding:
class StructuredPolicy(nn.Module):
    def __init__(self, env, hidden_size=256):
        super().__init__()
        self.hidden_size = hidden_size
        self.is_continuous = False  # Required by LSTMWrapper (see below)
        
        # Encoder: observations → hidden
        self.encoder = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(
                env.single_observation_space.shape[0], hidden_size)),
            nn.GELU(),
        )
        
        # Decoder: hidden → actions + values
        self.action_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, env.single_action_space.n), std=0.01)
        self.value_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1)

    def encode_observations(self, observations, state=None):
        """Encode observations to hidden state."""
        batch_size = observations.shape[0]
        observations = observations.view(batch_size, -1)
        return self.encoder(observations.float())

    def decode_actions(self, hidden):
        """Decode hidden state to actions and values."""
        logits = self.action_head(hidden)
        values = self.value_head(hidden)
        return logits, values

    def forward_eval(self, observations, state=None):
        hidden = self.encode_observations(observations, state)
        return self.decode_actions(hidden)

    def forward(self, observations, state=None):
        return self.forward_eval(observations, state)

Convolutional policy

For image observations (e.g., Atari):
class Convolutional(nn.Module):
    def __init__(self, env, framestack=4, flat_size=3136,
                 hidden_size=512, channels_last=False):
        super().__init__()
        self.channels_last = channels_last
        self.hidden_size = hidden_size
        self.is_continuous = False

        self.network = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Conv2d(framestack, 32, 8, stride=4)),
            nn.ReLU(),
            pufferlib.pytorch.layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            pufferlib.pytorch.layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            pufferlib.pytorch.layer_init(nn.Linear(flat_size, hidden_size)),
            nn.ReLU(),
        )
        self.actor = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, env.single_action_space.n), std=0.01)
        self.value_fn = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1)

    def encode_observations(self, observations, state=None):
        if self.channels_last:
            observations = observations.permute(0, 3, 1, 2)
        return self.network(observations.float() / 255.0)

    def decode_actions(self, hidden):
        action = self.actor(hidden)
        value = self.value_fn(hidden)
        return action, value

    def forward(self, observations, state=None):
        hidden = self.encode_observations(observations)
        return self.decode_actions(hidden)

    def forward_eval(self, observations, state=None):
        return self.forward(observations, state)
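
The flat_size argument is the flattened size of the last conv layer’s output. The default of 3136 comes from 84×84 Atari frames: the three convolutions reduce 84×84 to 7×7 with 64 channels, and 64 × 7 × 7 = 3136. For other resolutions you can measure it with a dummy forward pass; infer_flat_size below is a hypothetical helper, not part of PufferLib:
import torch
import torch.nn as nn

def infer_flat_size(framestack: int, height: int, width: int) -> int:
    # Mirror the conv stack above (without Flatten/Linear) and measure its output
    convs = nn.Sequential(
        nn.Conv2d(framestack, 32, 8, stride=4), nn.ReLU(),
        nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
        nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
    )
    with torch.no_grad():
        return convs(torch.zeros(1, framestack, height, width)).flatten(1).shape[1]

print(infer_flat_size(4, 84, 84))  # 3136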

LSTM wrapper

Wrap any structured policy with an LSTM:
from pufferlib.models import LSTMWrapper

# Create base policy
base_policy = StructuredPolicy(env, hidden_size=512)

# Wrap with LSTM
policy = LSTMWrapper(env, base_policy, 
                     input_size=512, 
                     hidden_size=512)
The LSTM wrapper automatically manages hidden states during training and inference.
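
Concretely, the state is a dict holding the recurrent tensors under the 'lstm_h' and 'lstm_c' keys (matching the implementation in the next section). A hand-rolled inference loop looks roughly like this sketch; the stand-in env is hypothetical:
import types
import torch

env = types.SimpleNamespace(
    single_observation_space=types.SimpleNamespace(shape=(16,)),
    single_action_space=types.SimpleNamespace(n=4),
)
policy = LSTMWrapper(env, StructuredPolicy(env, hidden_size=256),
                     input_size=256, hidden_size=256)

state = {'lstm_h': None, 'lstm_c': None}  # zero-initialized on the first step
obs = torch.randn(8, 16)                  # batch of 8 agents
for _ in range(3):
    logits, values = policy.forward_eval(obs, state)
    # forward_eval writes the updated h/c back into state each step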

LSTM wrapper implementation

Here’s how the LSTM wrapper works:
class LSTMWrapper(nn.Module):
    def __init__(self, env, policy, input_size=128, hidden_size=128):
        super().__init__()
        self.obs_shape = env.single_observation_space.shape
        self.policy = policy
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.is_continuous = policy.is_continuous

        # Initialize LSTM
        self.lstm = nn.LSTM(input_size, hidden_size)
        
        # Create LSTM cell for fast inference
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.cell.weight_ih = self.lstm.weight_ih_l0
        self.cell.weight_hh = self.lstm.weight_hh_l0
        self.cell.bias_ih = self.lstm.bias_ih_l0
        self.cell.bias_hh = self.lstm.bias_hh_l0

    def forward_eval(self, observations, state):
        """Fast inference with LSTMCell."""
        hidden = self.policy.encode_observations(observations, state=state)
        h = state['lstm_h']
        c = state['lstm_c']

        if h is not None:
            lstm_state = (h, c)
        else:
            lstm_state = None

        hidden, c = self.cell(hidden, lstm_state)
        state['lstm_h'] = hidden
        state['lstm_c'] = c
        
        logits, values = self.policy.decode_actions(hidden)
        return logits, values

    def forward(self, observations, state):
        """Training forward with batched LSTM."""
        x = observations
        lstm_h = state['lstm_h']
        lstm_c = state['lstm_c']

        x_shape, space_shape = x.shape, self.obs_shape
        x_n, space_n = len(x_shape), len(space_shape)

        # Handle batch and time dimensions
        if x_n == space_n + 1:
            B, TT = x_shape[0], 1
        elif x_n == space_n + 2:
            B, TT = x_shape[:2]
        else:
            raise ValueError('Invalid input tensor shape', x.shape)

        if lstm_h is not None:
            lstm_state = (lstm_h, lstm_c)
        else:
            lstm_state = None

        # Encode observations
        x = x.reshape(B*TT, *space_shape)
        hidden = self.policy.encode_observations(x, state)
        hidden = hidden.reshape(B, TT, self.input_size)

        # LSTM forward
        hidden = hidden.transpose(0, 1)
        hidden, (lstm_h, lstm_c) = self.lstm.forward(hidden, lstm_state)
        hidden = hidden.transpose(0, 1)

        # Decode actions
        flat_hidden = hidden.reshape(B*TT, self.hidden_size)
        logits, values = self.policy.decode_actions(flat_hidden)
        values = values.reshape(B, TT)
        
        state['lstm_h'] = lstm_h.detach()
        state['lstm_c'] = lstm_c.detach()
        return logits, values

Using LSTM wrapper

1. Structure your policy: split your policy into encode_observations() and decode_actions() methods.

2. Wrap with LSTM:

from pufferlib.models import LSTMWrapper

base_policy = MyPolicy(env, hidden_size=512)
policy = LSTMWrapper(env, base_policy,
                     input_size=512, hidden_size=512)

3. Configure training:

args['rnn_name'] = 'LSTMWrapper'
args['train']['use_rnn'] = True
args['train']['bptt_horizon'] = 64  # Sequence length

Continuous action spaces

For continuous actions, return a torch.distributions.Normal:
class ContinuousPolicy(nn.Module):
    def __init__(self, env, hidden_size=256):
        super().__init__()
        self.is_continuous = True
        
        self.encoder = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(
                env.single_observation_space.shape[0], hidden_size)),
            nn.Tanh(),
            pufferlib.pytorch.layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.Tanh(),
        )
        
        action_dim = env.single_action_space.shape[0]
        self.mean_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, action_dim), std=0.01)
        self.logstd = nn.Parameter(torch.zeros(1, action_dim))
        self.value_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1)

    def forward_eval(self, observations, state=None):
        hidden = self.encoder(observations.float())
        
        mean = self.mean_head(hidden)
        std = torch.exp(self.logstd.expand_as(mean))
        logits = torch.distributions.Normal(mean, std)
        
        values = self.value_head(hidden)
        return logits, values

    def forward(self, observations, state=None):
        return self.forward_eval(observations, state)
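
Since forward_eval returns a Normal distribution, sampling and log-probabilities come directly from torch.distributions. A minimal sketch, again with a hypothetical stand-in env:
import types
import torch

env = types.SimpleNamespace(
    single_observation_space=types.SimpleNamespace(shape=(8,)),
    single_action_space=types.SimpleNamespace(shape=(2,)),
)
policy = ContinuousPolicy(env)

dist, values = policy.forward_eval(torch.randn(4, 8))
actions = dist.sample()                        # shape (4, 2)
logprobs = dist.log_prob(actions).sum(dim=-1)  # joint log-prob per sample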

Multi-discrete action spaces

For multi-discrete actions:
class MultiDiscretePolicy(nn.Module):
    def __init__(self, env, hidden_size=256):
        super().__init__()
        self.is_multidiscrete = True
        self.action_nvec = tuple(env.single_action_space.nvec)
        
        self.encoder = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(
                env.single_observation_space.shape[0], hidden_size)),
            nn.GELU(),
        )
        
        num_actions = sum(self.action_nvec)
        self.action_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, num_actions), std=0.01)
        self.value_head = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1)

    def forward_eval(self, observations, state=None):
        hidden = self.encoder(observations.float())
        
        # Split into multiple discrete distributions
        logits = self.action_head(hidden).split(self.action_nvec, dim=1)
        values = self.value_head(hidden)
        
        return logits, values

    def forward(self, observations, state=None):
        return self.forward_eval(observations, state)
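
Each tensor in the returned tuple parameterizes one Categorical distribution. A sketch of sampling one sub-action per dimension, with a hypothetical stand-in for a MultiDiscrete([3, 5]) space:
import types
import torch

env = types.SimpleNamespace(
    single_observation_space=types.SimpleNamespace(shape=(8,)),
    single_action_space=types.SimpleNamespace(nvec=[3, 5]),
)
policy = MultiDiscretePolicy(env)

logits, values = policy.forward_eval(torch.randn(4, 8))
actions = torch.stack([
    torch.distributions.Categorical(logits=l).sample() for l in logits
], dim=-1)  # shape (4, 2): one sub-action per action dimension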

Layer initialization

Use pufferlib.pytorch.layer_init() for proper initialization:
import pufferlib.pytorch

# Default orthogonal initialization (std=sqrt(2))
layer = pufferlib.pytorch.layer_init(nn.Linear(128, 64))

# Custom standard deviation
action_head = pufferlib.pytorch.layer_init(nn.Linear(64, 4), std=0.01)
value_head = pufferlib.pytorch.layer_init(nn.Linear(64, 1), std=1.0)
This applies orthogonal initialization to the weights and zero-initializes the biases.
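
Conceptually, layer_init is equivalent to roughly the following sketch (based on the behavior described above, not the verbatim PufferLib source):
import torch
import torch.nn as nn

def layer_init(layer: nn.Module, std: float = 2**0.5, bias_const: float = 0.0):
    torch.nn.init.orthogonal_(layer.weight, std)     # orthogonal weights scaled by std
    torch.nn.init.constant_(layer.bias, bias_const)  # zero biases by default
    return layer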

Using custom policies

In Python

from pufferlib import pufferl

args = pufferl.load_config('puffer_breakout')
vecenv = pufferl.load_env('puffer_breakout', args)

# Use custom policy
policy = MyCustomPolicy(vecenv.driver_env, hidden_size=512).cuda()

trainer = pufferl.PuffeRL(args['train'], vecenv, policy)

Via configuration

Create a policy module:
custom_policies.py
import torch.nn as nn
import pufferlib.pytorch

class MyPolicy(nn.Module):
    def __init__(self, env, hidden_size=256):
        # ... implementation ...
        pass
Reference it in your config:
[base]
package = my_package
env_name = my_env
policy_name = MyPolicy

[policy]
hidden_size = 512

Best practices

1. Use encode/decode structure: split policies into encode_observations() and decode_actions() for LSTM compatibility.

2. Initialize layers properly: use pufferlib.pytorch.layer_init() with appropriate std values (0.01 for action heads, 1.0 for value heads).

3. Match forward signatures: both forward() and forward_eval() should accept (observations, state=None) and return (logits, values).

4. Handle observation shapes: flatten observations appropriately for your architecture.

5. Set the is_continuous flag: set self.is_continuous = True for continuous action spaces, and return a torch.distributions.Normal distribution rather than raw tensors.
