
Default policy

A simple fully-connected policy that works with any observation and action space.
from pufferlib.models import Default

policy = Default(env, hidden_size=128)

Parameters

  • env (Environment, required): Environment instance to extract observation and action space information.
  • hidden_size (int, default 128): Size of the hidden layer.

Properties

  • hidden_size (int): Size of the hidden layer.
  • is_multidiscrete (bool): Whether the action space is multi-discrete.
  • is_continuous (bool): Whether the action space is continuous.
  • is_dict_obs (bool): Whether observations are dictionary-based.

Methods

forward

Forward pass through the policy.
logits, values = policy.forward(observations, state=None)
Arguments:
  • observations (torch.Tensor, required): Batch of observations from the environment.
  • state (dict, default None): Optional state dictionary (used by recurrent policies).
Returns:
  • logits (torch.Tensor | torch.distributions.Normal | tuple): Action logits (discrete), Normal distribution (continuous), or tuple of logits (multi-discrete).
  • values (torch.Tensor): Value predictions for each observation.
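As an illustration of the return contract for a discrete action space, the logits can be turned into sampled actions with a Categorical distribution. This is a sketch using plain PyTorch with random stand-in tensors, not PufferLib's own sampling code:

```python
import torch

# Stand-ins for the (logits, values) pair that policy.forward
# would return for a discrete action space.
batch_size, num_actions = 8, 4
logits = torch.randn(batch_size, num_actions)
values = torch.randn(batch_size)

# Discrete case: build a Categorical distribution over actions
dist = torch.distributions.Categorical(logits=logits)
actions = dist.sample()             # shape: (batch_size,)
log_probs = dist.log_prob(actions)  # shape: (batch_size,)
```

In the continuous case, forward returns a torch.distributions.Normal directly, so sampling and log-prob computation follow the same pattern without constructing a distribution first.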

forward_eval

Forward pass for evaluation (same as forward for Default policy).
logits, values = policy.forward_eval(observations, state=None)
Arguments:
  • observations (torch.Tensor, required): Batch of observations.
  • state (dict, default None): Optional state dictionary.
Returns:
  • logits (torch.Tensor | torch.distributions.Normal | tuple): Action logits or distribution.
  • values (torch.Tensor): Value predictions.

encode_observations

Encode observations into hidden states.
hidden = policy.encode_observations(observations, state=None)
Arguments:
  • observations (torch.Tensor, required): Batch of observations. Automatically flattened if dictionary-based.
  • state (dict, default None): Optional state dictionary.
Returns:
  • hidden (torch.Tensor): Encoded hidden states with shape (batch_size, hidden_size).

decode_actions

Decode hidden states into action logits and values.
logits, values = policy.decode_actions(hidden)
Arguments:
  • hidden (torch.Tensor, required): Hidden states from encode_observations.
Returns:
  • logits (torch.Tensor | torch.distributions.Normal | tuple): Action logits or distribution.
  • values (torch.Tensor): Value predictions.

LSTMWrapper

Wraps a policy with an LSTM layer for temporal processing.
from pufferlib.models import LSTMWrapper

policy = LSTMWrapper(env, base_policy, input_size=128, hidden_size=128)

Parameters

  • env (Environment, required): Environment instance.
  • policy (nn.Module, required): Base policy to wrap. Must implement encode_observations and decode_actions.
  • input_size (int, default 128): Input size to the LSTM (should match the base policy's hidden size).
  • hidden_size (int, default 128): LSTM hidden state size.

Properties

  • hidden_size (int): LSTM hidden state size.
  • input_size (int): Input size to the LSTM.
  • is_continuous (bool): Whether the action space is continuous (inherited from the base policy).

Methods

forward

Forward pass for training with full time sequences.
logits, values = lstm_policy.forward(observations, state)
Arguments:
  • observations (torch.Tensor, required): Observations with shape (batch_size, time_steps, *obs_shape) or (batch_size, *obs_shape).
  • state (dict, required): State dictionary containing:
      • lstm_h: LSTM hidden state or None
      • lstm_c: LSTM cell state or None
      • action: Actions (when computing log probs)
Returns:
  • logits (torch.Tensor | torch.distributions.Normal | tuple): Action logits or distribution.
  • values (torch.Tensor): Value predictions with shape (batch_size, time_steps).

forward_eval

Fast forward pass for inference (3x faster than using LSTM directly).
logits, values = lstm_policy.forward_eval(observations, state)
Arguments:
  • observations (torch.Tensor, required): Single-timestep observations with shape (batch_size, *obs_shape).
  • state (dict, required): State dictionary, updated in place with new LSTM states:
      • lstm_h: LSTM hidden state
      • lstm_c: LSTM cell state
      • hidden: Encoded hidden state
Returns:
  • logits (torch.Tensor | torch.distributions.Normal | tuple): Action logits or distribution.
  • values (torch.Tensor): Value predictions.
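The state-threading pattern that forward_eval maintains can be sketched with a plain nn.LSTM. This is an illustration of how lstm_h and lstm_c are carried across single-timestep calls, not the actual LSTMWrapper internals:

```python
import torch
import torch.nn as nn

batch_size, input_size, hidden_size = 8, 128, 128
lstm = nn.LSTM(input_size, hidden_size)  # seq-first by default

# Start of an episode: no recurrent state yet
state = {'lstm_h': None, 'lstm_c': None}

for step in range(3):
    x = torch.randn(1, batch_size, input_size)  # one timestep
    if state['lstm_h'] is None:
        out, (h, c) = lstm(x)  # zero-initialized state
    else:
        out, (h, c) = lstm(x, (state['lstm_h'], state['lstm_c']))
    # Update the state dict in place, as forward_eval does
    state['lstm_h'], state['lstm_c'] = h, c
```

Resetting the state entries to None at episode boundaries restarts the recurrence from a zero-initialized state.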

Convolutional

NatureCNN architecture from CleanRL, designed for Atari and visual observations.
from pufferlib.models import Convolutional

policy = Convolutional(
    env,
    framestack=4,
    flat_size=3136,
    input_size=512,
    hidden_size=512,
    output_size=512,
    channels_last=False,
    downsample=1
)

Parameters

  • env (Environment, required): Environment instance.
  • framestack (int, required): Number of frames stacked in the observation (1 with LSTM, 4 without).
  • flat_size (int, required): Size of the flattened feature map after the convolutions.
  • input_size (int, default 512): Input size to the final linear layer.
  • hidden_size (int, default 512): Hidden layer size.
  • output_size (int, default 512): Output size of the encoded features.
  • channels_last (bool, default False): Whether observations are in channels-last format.
  • downsample (int, default 1): Downsampling factor for observations.
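flat_size depends on the observation resolution. Assuming the standard NatureCNN conv stack (32 filters kernel 8 stride 4, then 64 filters kernel 4 stride 2, then 64 filters kernel 3 stride 1, no padding), the common value 3136 corresponds to 84x84 frames; this helper sketches the computation for other resolutions:

```python
def conv_out(size, kernel, stride):
    # Output spatial size of a valid (no padding) convolution
    return (size - kernel) // stride + 1

def nature_cnn_flat_size(height, width=None):
    """Flattened feature size after the standard NatureCNN stack:
    32 filters k8 s4, then 64 filters k4 s2, then 64 filters k3 s1."""
    width = height if width is None else width
    for kernel, stride in [(8, 4), (4, 2), (3, 1)]:
        height = conv_out(height, kernel, stride)
        width = conv_out(width, kernel, stride)
    return 64 * height * width  # 64 channels in the final conv layer

print(nature_cnn_flat_size(84))  # 3136 for 84x84 Atari frames
```

For 84x84 inputs the spatial size shrinks 84 -> 20 -> 9 -> 7, giving 64 * 7 * 7 = 3136. If you use the downsample parameter, apply it to the input resolution before this calculation.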

Methods

forward

Forward pass through the network.
actions, value = policy.forward(observations, state=None)
Arguments:
  • observations (torch.Tensor, required): Image observations with shape (batch, channels, height, width).
  • state (dict, default None): Optional state dictionary.
Returns:
  • actions (torch.Tensor): Action logits.
  • value (torch.Tensor): Value predictions.

encode_observations

Encode image observations into features.
features = policy.encode_observations(observations, state=None)
Arguments:
  • observations (torch.Tensor, required): Image observations (normalized to the [0, 1] range internally).
  • state (dict, default None): Optional state dictionary.
Returns:
  • features (torch.Tensor): Encoded features.

decode_actions

Decode features into actions and values.
action, value = policy.decode_actions(flat_hidden)
Arguments:
  • flat_hidden (torch.Tensor, required): Flattened hidden features.
Returns:
  • action (torch.Tensor): Action logits.
  • value (torch.Tensor): Value predictions.

ProcgenResnet

ResNet-based architecture from the Impala paper, used for Procgen environments.
from pufferlib.models import ProcgenResnet

policy = ProcgenResnet(env, cnn_width=16, mlp_width=256)

Parameters

  • env (Environment, required): Environment instance.
  • cnn_width (int, default 16): Width multiplier for the convolutional layers.
  • mlp_width (int, default 256): Width of the final MLP layer.

Methods

forward

Forward pass through the network.
actions, value = policy.forward(observations, state=None)
Arguments:
  • observations (torch.Tensor, required): Image observations with shape (batch, height, width, channels).
  • state (dict, default None): Optional state dictionary.
Returns:
  • actions (torch.Tensor): Action logits.
  • value (torch.Tensor): Value predictions.

encode_observations

Encode image observations through ResNet blocks.
hidden = policy.encode_observations(x)
Arguments:
  • x (torch.Tensor, required): Image observations (normalized to the [0, 1] range internally).
Returns:
  • hidden (torch.Tensor): Encoded features.

decode_actions

Decode features into actions and values.
action, value = policy.decode_actions(hidden)
Arguments:
  • hidden (torch.Tensor, required): Encoded features.
Returns:
  • action (torch.Tensor): Action logits.
  • value (torch.Tensor): Value predictions.

Usage examples

import gymnasium as gym

import pufferlib.vector
import pufferlib.models

# Create environment
vecenv = pufferlib.vector.make(
    lambda: gym.make('CartPole-v1'),
    num_envs=8
)

# Create default policy
policy = pufferlib.models.Default(
    vecenv.driver_env,
    hidden_size=128
)

# Use with PuffeRL
from pufferlib.pufferl import PuffeRL
pufferl = PuffeRL(config, vecenv, policy)

Custom policies

You can create custom policies by following the PufferLib convention:
import torch.nn as nn
import pufferlib.pytorch

class CustomPolicy(nn.Module):
    def __init__(self, env, hidden_size=256):
        super().__init__()
        self.hidden_size = hidden_size
        self.is_continuous = False

        # Derive sizes from the env's spaces (assumes flat Box
        # observations and a Discrete action space)
        obs_size = env.observation_space.shape[0]
        num_actions = env.action_space.n

        # Define your architecture
        self.encoder = nn.Sequential(
            pufferlib.pytorch.layer_init(nn.Linear(obs_size, hidden_size)),
            nn.ReLU(),
        )
        self.decoder = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, num_actions), std=0.01
        )
        self.value = pufferlib.pytorch.layer_init(
            nn.Linear(hidden_size, 1), std=1
        )
    
    def forward(self, observations, state=None):
        return self.forward_eval(observations, state)
    
    def forward_eval(self, observations, state=None):
        hidden = self.encode_observations(observations, state)
        return self.decode_actions(hidden)
    
    def encode_observations(self, observations, state=None):
        # Encode observations to hidden states
        return self.encoder(observations)
    
    def decode_actions(self, hidden):
        # Decode to actions and values
        logits = self.decoder(hidden)
        values = self.value(hidden)
        return logits, values
The key requirements are:
  • Implement forward() and forward_eval() methods
  • Return (logits, values) from forward methods
  • For LSTM wrapper compatibility, implement encode_observations() and decode_actions()
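The convention can be smoke-tested end to end with a minimal standalone policy. This sketch uses plain nn.Linear in place of pufferlib.pytorch.layer_init so it runs without PufferLib installed; all sizes are illustrative:

```python
import torch
import torch.nn as nn

class TinyPolicy(nn.Module):
    """Minimal policy following the PufferLib convention. Plain
    nn.Linear stands in for pufferlib.pytorch.layer_init."""
    def __init__(self, obs_size=4, num_actions=2, hidden_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.is_continuous = False
        self.encoder = nn.Sequential(
            nn.Linear(obs_size, hidden_size), nn.ReLU()
        )
        self.decoder = nn.Linear(hidden_size, num_actions)
        self.value = nn.Linear(hidden_size, 1)

    def forward(self, observations, state=None):
        return self.forward_eval(observations, state)

    def forward_eval(self, observations, state=None):
        hidden = self.encode_observations(observations, state)
        return self.decode_actions(hidden)

    def encode_observations(self, observations, state=None):
        return self.encoder(observations)

    def decode_actions(self, hidden):
        return self.decoder(hidden), self.value(hidden)

policy = TinyPolicy()
logits, values = policy(torch.randn(8, 4))
```

Because it implements encode_observations and decode_actions, a policy like this can also be wrapped with LSTMWrapper.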
