Default policy
A simple fully-connected policy that works with any observation and action space.

Parameters
- Environment instance used to extract observation and action space information.
- Size of the hidden layer.
Properties
- Size of the hidden layer.
- Whether the action space is multi-discrete.
- Whether the action space is continuous.
- Whether observations are dictionary-based.
Methods
forward
Forward pass through the policy.

Parameters:
- Batch of observations from the environment.
- Optional state dictionary (used by recurrent policies).

Returns:
- Action logits (discrete), a Normal distribution (continuous), or a tuple of logits (multi-discrete).
- Value predictions for each observation.
forward_eval
Forward pass for evaluation (same as forward for the Default policy).

Parameters:
- Batch of observations.
- Optional state dictionary.

Returns:
- Action logits or distribution.
- Value predictions.
encode_observations
Encode observations into hidden states.

Parameters:
- Batch of observations. Automatically flattened if dictionary-based.
- Optional state dictionary.

Returns:
- Encoded hidden states with shape (batch_size, hidden_size).

decode_actions
Decode hidden states into action logits and values.

Parameters:
- Hidden states from encode_observations.

Returns:
- Action logits or distribution.
- Value predictions.
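As a rough illustration of the decode step, the sketch below uses plain numpy (not the actual PufferLib implementation; the weight names are hypothetical) to show how a batch of hidden states can be decoded into discrete action logits, or into the (mean, std) pair a framework would wrap in a Normal distribution, alongside value predictions:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, n_actions = 8, 4

# Hypothetical decode heads: one linear layer per output.
W_logits = rng.normal(size=(hidden_size, n_actions)) * 0.1
W_value = rng.normal(size=(hidden_size, 1)) * 0.1
log_std = np.zeros(n_actions)  # state-independent log-std for the continuous case

def decode_discrete(hidden):
    """Return (logits, values) for a discrete action space."""
    return hidden @ W_logits, (hidden @ W_value).squeeze(-1)

def decode_continuous(hidden):
    """Return ((mean, std), values); a framework would build a Normal from these."""
    mean = hidden @ W_logits
    return (mean, np.exp(log_std)), (hidden @ W_value).squeeze(-1)

hidden = rng.normal(size=(2, hidden_size))  # batch of 2 encoded observations
logits, values = decode_discrete(hidden)
print(logits.shape, values.shape)  # (2, 4) (2,)
```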
LSTMWrapper
Wraps a policy with an LSTM layer for temporal processing.

Parameters
- Environment instance.
- Base policy to wrap. Must implement encode_observations and decode_actions.
- Input size to the LSTM (should match the base policy's hidden size).
- LSTM hidden state size.
Properties
- LSTM hidden state size.
- Input size to the LSTM.
- Whether the action space is continuous (inherited from the base policy).
Methods
forward
Forward pass for training with full time sequences.

Parameters:
- Observations with shape (batch_size, time_steps, *obs_shape) or (batch_size, *obs_shape).
- State dictionary containing:
  - lstm_h: LSTM hidden state or None
  - lstm_c: LSTM cell state or None
  - action: Actions (when computing log probs)

Returns:
- Action logits or distribution.
- Value predictions with shape (batch_size, time_steps).

forward_eval
Fast forward pass for inference (3x faster than using the LSTM directly).

Parameters:
- Single-timestep observations with shape (batch_size, *obs_shape).
- State dictionary. Updated in place with new LSTM states:
  - lstm_h: LSTM hidden state
  - lstm_c: LSTM cell state
  - hidden: Encoded hidden state

Returns:
- Action logits or distribution.
- Value predictions.
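To make the in-place state handling concrete, here is a numpy sketch of a single-step recurrent call in the style of forward_eval. It is not PufferLib's implementation (the real wrapper uses a torch LSTM); the weights and the hand-rolled LSTM cell are illustrative only, but the state dictionary keys mirror the ones documented above:

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, hidden_size = 8, 16

# Hypothetical single-layer LSTM weights (gates stacked as i, f, g, o).
Wx = rng.normal(size=(input_size, 4 * hidden_size)) * 0.1
Wh = rng.normal(size=(hidden_size, 4 * hidden_size)) * 0.1
b = np.zeros(4 * hidden_size)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def forward_eval(x, state):
    """One LSTM step; mutates `state` in place, like the wrapper's forward_eval."""
    batch = x.shape[0]
    h = state["lstm_h"] if state["lstm_h"] is not None else np.zeros((batch, hidden_size))
    c = state["lstm_c"] if state["lstm_c"] is not None else np.zeros((batch, hidden_size))
    gates = x @ Wx + h @ Wh + b
    i, f, g, o = np.split(gates, 4, axis=-1)
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    state["lstm_h"], state["lstm_c"], state["hidden"] = h, c, h
    return h

state = {"lstm_h": None, "lstm_c": None}
for _ in range(3):  # three single-timestep calls reuse the carried state
    h = forward_eval(rng.normal(size=(2, input_size)), state)
print(h.shape)  # (2, 16)
```

The carried state is what lets inference process one timestep at a time instead of a full sequence.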
Convolutional
NatureCNN architecture from CleanRL, designed for Atari and visual observations.

Parameters
- Environment instance.
- Number of frames stacked in the observation (1 with LSTM, 4 without).
- Size of the flattened feature map after the convolutions.
- Input size of the final linear layer.
- Hidden layer size.
- Output size of the encoded features.
- Whether observations are in channels-last format.
- Downsampling factor for observations.
Methods
forward
Forward pass through the network.

Parameters:
- Image observations with shape (batch, channels, height, width).
- Optional state dictionary.

Returns:
- Action logits.
- Value predictions.
encode_observations
Encode image observations into features.

Parameters:
- Image observations (will be normalized to the [0, 1] range).
- Optional state dictionary.

Returns:
- Encoded features.
decode_actions
Decode features into actions and values.

Parameters:
- Flattened hidden features.

Returns:
- Action logits.
- Value predictions.
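The input-format details above (channels-first tensors, [0, 1] normalization) can be sketched as a small preprocessing step. This is a hypothetical helper in plain numpy, not PufferLib code, showing the kind of conversion a channels-last Atari frame stack needs before entering such a network:

```python
import numpy as np

def preprocess(obs):
    """Hypothetical Atari-style preprocessing: uint8 channels-last frames
    are normalized to [0, 1] and permuted to (batch, channels, height, width)."""
    x = obs.astype(np.float32) / 255.0    # normalize pixel values to [0, 1]
    return np.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW

obs = np.zeros((8, 84, 84, 4), dtype=np.uint8)  # batch of 4-frame stacks
x = preprocess(obs)
print(x.shape, x.dtype)  # (8, 4, 84, 84) float32
```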
ProcgenResnet
ResNet-based architecture from the Impala paper, used for Procgen environments.

Parameters
- Environment instance.
- Width multiplier for the convolutional layers.
- Width of the final MLP layer.
Methods
forward
Forward pass through the network.

Parameters:
- Image observations with shape (batch, height, width, channels).
- Optional state dictionary.

Returns:
- Action logits.
- Value predictions.
encode_observations
Encode image observations through ResNet blocks.

Parameters:
- Image observations (normalized to [0, 1] internally).

Returns:
- Encoded features.
decode_actions
Decode features into actions and values.

Parameters:
- Encoded features.

Returns:
- Action logits.
- Value predictions.
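The defining feature of the Impala-style blocks is the residual (skip) connection. The numpy sketch below shows only that pattern; it is not the actual block (the real layers are spatial convolutions, replaced here by hypothetical per-pixel channel-mixing weights to keep the example tiny):

```python
import numpy as np

rng = np.random.default_rng(0)
channels = 16

# Hypothetical channel-mixing weights standing in for the block's convolutions.
W1 = rng.normal(size=(channels, channels)) * 0.1
W2 = rng.normal(size=(channels, channels)) * 0.1

def residual_block(x):
    """relu -> mix -> relu -> mix, then add the skip connection, mirroring
    the residual structure of Impala-style blocks."""
    out = np.maximum(x, 0.0) @ W1
    out = np.maximum(out, 0.0) @ W2
    return x + out  # identity skip keeps gradients flowing through deep stacks

x = rng.normal(size=(2, 8, 8, channels))  # (batch, height, width, channels)
y = residual_block(x)
print(y.shape)  # (2, 8, 8, 16)
```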
Usage examples
Custom policies
You can create custom policies by following the PufferLib convention:
- Implement forward() and forward_eval() methods.
- Return (logits, values) from the forward methods.
- For LSTM wrapper compatibility, implement encode_observations() and decode_actions().
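The convention above can be sketched as a minimal class. This is a framework-free numpy stand-in (real PufferLib policies are torch modules; the class name and weights here are hypothetical), showing the method split that makes a policy LSTM-wrapper compatible:

```python
import numpy as np

rng = np.random.default_rng(0)

class CustomPolicy:
    """Minimal sketch of the PufferLib policy convention in plain numpy."""

    def __init__(self, obs_size, hidden_size, n_actions):
        scale = 0.1
        self.W_enc = rng.normal(size=(obs_size, hidden_size)) * scale
        self.W_act = rng.normal(size=(hidden_size, n_actions)) * scale
        self.W_val = rng.normal(size=(hidden_size, 1)) * scale

    def encode_observations(self, obs, state=None):
        """Observations -> hidden states; the piece an LSTM wrapper sits after."""
        return np.tanh(obs @ self.W_enc)  # (batch, hidden_size)

    def decode_actions(self, hidden):
        """Hidden states -> (logits, values); the piece after the LSTM."""
        return hidden @ self.W_act, (hidden @ self.W_val).squeeze(-1)

    def forward(self, obs, state=None):
        # Convention: forward returns (logits, values).
        return self.decode_actions(self.encode_observations(obs, state))

    def forward_eval(self, obs, state=None):
        # Same as forward for a non-recurrent policy.
        return self.forward(obs, state)

policy = CustomPolicy(obs_size=12, hidden_size=32, n_actions=5)
logits, values = policy.forward(np.zeros((4, 12)))
print(logits.shape, values.shape)  # (4, 5) (4,)
```

Keeping encode_observations and decode_actions separate is what lets a recurrent wrapper insert its LSTM between the two without the base policy changing.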