Skip to main content
PufferLib’s wrapper system is designed to be extensible. You can create custom wrappers to add preprocessing, modify observations, or integrate new environment libraries.

Wrapper types

There are three main types of wrappers you might create:
  1. Preprocessing wrappers - Modify observations or rewards (Gymnasium/PettingZoo API)
  2. Integration wrappers - Add new environment libraries to PufferLib
  3. Native PufferEnv - Implement environments directly in the PufferEnv interface

Preprocessing wrappers

Preprocessing wrappers modify the environment before conversion to PufferEnv. They follow the standard Gymnasium or PettingZoo wrapper pattern.

Gymnasium preprocessing wrapper

Here’s a simple wrapper that linearly rescales observations to [-1, 1] using the bounds of the observation space (which therefore must be finite):
import gymnasium
import numpy as np

class NormalizeObservation(gymnasium.Wrapper):
    """Linearly rescale Box observations into the [-1, 1] range.

    Each dimension is mapped with ``2 * (obs - low) / (high - low) - 1`` using
    the wrapped env's Box bounds. The result is clipped and cast to float32 so
    it always lies inside the declared observation space.

    Raises:
        ValueError: if any bound is infinite, or so large that ``high - low``
            overflows float32 (e.g. CartPole-v1 uses float32-max sentinels for
            its velocity dimensions). Rescaling such a dimension would collapse
            it to a constant or NaN.
    """
    def __init__(self, env):
        super().__init__(env)
        low = np.asarray(env.observation_space.low, dtype=np.float32)
        high = np.asarray(env.observation_space.high, dtype=np.float32)

        # Compute the span in float32 on purpose: float32-max sentinel bounds
        # overflow to inf here and are rejected together with real infinities.
        with np.errstate(over='ignore', invalid='ignore'):
            span = high - low
        if not np.all(np.isfinite(span)):
            raise ValueError(
                'NormalizeObservation requires finite observation bounds; '
                f'got low={low}, high={high}'
            )

        self.observation_space = gymnasium.spaces.Box(
            low=-1.0,
            high=1.0,
            shape=env.observation_space.shape,
            dtype=np.float32
        )
        self.obs_low = low
        self.obs_high = high
        # Zero-width dimensions (low == high) would divide by zero; use a unit
        # span so they map to a constant instead of NaN.
        self.obs_span = np.where(span > 0, span, np.float32(1.0))

    def normalize(self, obs):
        """Return ``obs`` rescaled to [-1, 1] as a float32 array."""
        scaled = 2 * (np.asarray(obs, dtype=np.float32) - self.obs_low) / self.obs_span - 1
        # Observations may slightly exceed the advertised bounds; clip so the
        # result stays inside the declared Box.
        return np.clip(scaled, -1.0, 1.0).astype(np.float32, copy=False)

    def reset(self, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options=options)
        return self.normalize(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self.normalize(obs), reward, terminated, truncated, info
Usage:
import gymnasium
import pufferlib.emulation

# NormalizeObservation rescales with the observation bounds, so use an env
# whose Box bounds are finite in every dimension. CartPole-v1 is unsuitable:
# its velocity bounds are float32-max sentinels.
env = gymnasium.make('MountainCar-v0')
env = NormalizeObservation(env)  # Apply custom wrapper
env = pufferlib.emulation.GymnasiumPufferEnv(env=env)

Real example: ResizeObservation

PufferLib’s ResizeObservation wrapper efficiently downscales images. From pufferlib/pufferlib.py:114:
import gymnasium
import numpy as np

class ResizeObservation(gymnasium.Wrapper):
    """Fixed downscaling wrapper using fast NumPy slicing.

    Keeps every ``downscale``-th row and column (strided subsampling, no
    averaging). Much faster than OpenCV-based resizing (-50% overhead on Atari).

    Observations must have shape ``(H, W)`` or ``(H, W, ...)``; any trailing
    dimensions (e.g. color channels) are passed through unchanged. The new
    space is declared as uint8 in [0, 255], i.e. image observations.

    Raises:
        ValueError: if ``downscale`` < 1, the observation has fewer than two
            dimensions, or H / W are not divisible by ``downscale``.
    """
    def __init__(self, env, downscale=2):
        super().__init__(env)
        # Explicit checks rather than ``assert``, which is stripped under -O.
        if downscale < 1:
            raise ValueError(f'downscale must be >= 1, got {downscale}')
        self.downscale = downscale

        shape = env.observation_space.shape
        if len(shape) < 2:
            raise ValueError(
                f'ResizeObservation needs (H, W[, ...]) observations, got shape {shape}'
            )
        y_size, x_size = shape[0], shape[1]
        if y_size % downscale or x_size % downscale:
            raise ValueError(
                f'observation size {y_size}x{x_size} is not divisible by downscale={downscale}'
            )

        self.observation_space = gymnasium.spaces.Box(
            low=0, high=255,
            shape=(y_size // downscale, x_size // downscale, *shape[2:]),
            dtype=np.uint8
        )

    def _downscale(self, obs):
        """Subsample the first two axes of ``obs`` by ``self.downscale``."""
        return obs[::self.downscale, ::self.downscale]

    def reset(self, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options=options)
        return self._downscale(obs), info

    def step(self, action):
        obs, reward, terminal, truncated, info = self.env.step(action)
        return self._downscale(obs), reward, terminal, truncated, info

Real example: EpisodeStats

Tracks cumulative episode statistics. From pufferlib/pufferlib.py:155:
import gymnasium

class EpisodeStats(gymnasium.Wrapper):
    """Store episodic returns and lengths in the info dict.

    Rewards and (flattened) info values are accumulated over an episode. On
    the step where ``terminated or truncated`` is true, the returned info holds
    each accumulated series summed (``episode_return``, ``episode_length`` and
    any numeric info key); values that cannot be summed are returned as-is.
    On every other step the returned info dict is empty.
    """
    def __init__(self, env):
        # Let gymnasium.Wrapper set up its internal state. Spaces default to
        # the wrapped env's.
        super().__init__(env)
        # Only clear the statistics here. Resetting the wrapped env during
        # construction would be a hidden side effect; callers reset explicitly.
        self._reset_stats()

    def _reset_stats(self):
        """Clear the per-episode accumulators."""
        self.info = dict(episode_return=[], episode_length=0)

    def reset(self, seed=None, options=None):
        self._reset_stats()
        return self.env.reset(seed=seed, options=options)

    def step(self, action):
        observation, reward, terminated, truncated, info = super().step(action)

        # Accumulate (possibly nested) info values. unroll_nested_dict is not
        # defined in this snippet; presumably it is the helper next to this
        # class in pufferlib/pufferlib.py. It yields (key, value) pairs.
        for k, v in unroll_nested_dict(info):
            self.info.setdefault(k, []).append(v)

        self.info['episode_return'].append(reward)
        self.info['episode_length'] += 1

        # Stats are only reported on the final step of an episode.
        info = {}
        if terminated or truncated:
            for k, v in self.info.items():
                try:
                    info[k] = sum(v)  # Sum numeric series
                except TypeError:
                    # Non-summable values (including the plain int
                    # episode_length) are reported unchanged.
                    info[k] = v

        return observation, reward, terminated, truncated, info

PettingZoo preprocessing wrapper

For multi-agent environments, use PettingZooWrapper as a base:
import pufferlib

class PettingZooWrapper:
    """Base wrapper for PettingZoo environments.

    Unknown public attributes are forwarded to the wrapped ``env``. Private
    attributes (leading underscore) are not forwarded, except
    ``_cumulative_rewards``.
    """
    def __init__(self, env):
        self.env = env

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``env`` itself is
        # missing (e.g. an instance built via __new__ during copy/unpickling),
        # forwarding would recurse forever, so fail with AttributeError.
        if name == 'env':
            raise AttributeError(f"{type(self).__name__!r} object has no attribute 'env'")
        if name.startswith('_') and name != '_cumulative_rewards':
            raise AttributeError(f'accessing private attribute "{name}" is prohibited')
        return getattr(self.env, name)

    @property
    def unwrapped(self):
        return self.env.unwrapped

    def close(self):
        self.env.close()

    def render(self):
        return self.env.render()

    def reset(self, seed=None, options=None):
        """Reset the wrapped env, forwarding ``options`` only when given.

        Some PettingZoo envs do not accept ``options``. If no options are
        passed, they are omitted entirely. This avoids retrying (and running
        reset twice) when the env's reset raises TypeError internally.
        """
        if options is None:
            return self.env.reset(seed=seed)
        try:
            return self.env.reset(seed=seed, options=options)
        except TypeError:
            # Fallback for envs whose reset() has no ``options`` parameter.
            return self.env.reset(seed=seed)

    def step(self, action):
        return self.env.step(action)

Real example: MeanOverAgents

Averages info dicts across agents. From pufferlib/pufferlib.py:250:
import numpy as np

class MeanOverAgents(PettingZooWrapper):
    """Averages info values over all agents.

    Per-agent infos ``{agent: {key: value}}`` are collapsed into a single
    ``{key: mean}`` dict. Keys whose values ``np.mean`` cannot average
    (strings, nested dicts, ragged arrays, ...) are dropped.
    """
    def _mean(self, infos):
        list_infos = {}
        for info in infos.values():
            for k, v in info.items():
                list_infos.setdefault(k, []).append(v)

        mean_infos = {}
        for k, v in list_infos.items():
            try:
                mean_infos[k] = np.mean(v)
            except (TypeError, ValueError):
                # Non-numeric or inhomogeneous values: skip this key. A bare
                # ``except`` would also swallow KeyboardInterrupt/SystemExit
                # and unrelated bugs.
                continue

        return mean_infos

    def reset(self, seed=None, options=None):
        observations, infos = super().reset(seed, options)
        infos = self._mean(infos)
        return observations, infos

    def step(self, actions):
        observations, rewards, terminations, truncations, infos = super().step(actions)
        infos = self._mean(infos)
        return observations, rewards, terminations, truncations, infos

Environment-specific wrappers

Some environments need custom preprocessing. Here are real examples from PufferLib.

MiniGrid wrapper

Removes the mission string from observations. From pufferlib/environments/minigrid/environment.py:27:
import gymnasium

class MiniGridWrapper:
    """Remove the mission string from MiniGrid observations and cap episode length.

    Args:
        env: a MiniGrid env whose Dict observation space contains 'mission'.
        max_steps: number of steps after which the episode is truncated.
            Defaults to 100, the previously hard-coded limit.
    """
    def __init__(self, env, max_steps=100):
        self.env = env
        self.max_steps = max_steps
        # Filter out 'mission' from observation space
        self.observation_space = gymnasium.spaces.Dict({
            k: v for k, v in env.observation_space.items()
            if k != 'mission'
        })
        self.action_space = env.action_space
        self.close = env.close
        self.render = env.render
        self.render_mode = 'rgb_array'

    def reset(self, seed=None, options=None):
        self.tick = 0
        obs, info = self.env.reset(seed=seed, options=options)
        del obs['mission']  # Remove mission string
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        del obs['mission']

        # Hitting a step limit is a truncation, not a terminal state, in the
        # Gymnasium API. ``>=`` keeps flagging it if stepping continues.
        self.tick += 1
        if self.tick >= self.max_steps:
            truncated = True

        return obs, reward, terminated, truncated, info

NetHack wrapper

Renders text as images for visual observation. From pufferlib/environments/nethack/wrapper.py:136:
import gym
import numpy as np
from numba import njit

class RenderCharImagesWithNumpyWrapper(gym.Wrapper):
    """Render NetHack characters as RGB images.

    Replaces the dict observation (``tty_chars``, ``tty_colors``,
    ``tty_cursor``) with a uint8 CHW image of shape
    ``(3, crop_size * char_height, crop_size * char_width)``, cropped around
    the cursor. Follows the older gym API: ``reset`` returns only the
    observation and ``step`` returns a 4-tuple ``(obs, reward, done, info)``.
    """
    def __init__(self, env, font_size=9, crop_size=12, rescale_font_size=(6, 6)):
        super().__init__(env)

        # Pre-render all characters in all colors. _initialize_char_array is a
        # module-level helper not shown here. The indexing below assumes a 5-D
        # array with glyph height/width on axes 2/3 and color channels on
        # axis 4 (TODO confirm).
        self.char_array = _initialize_char_array(font_size, rescale_font_size)
        self.char_height = self.char_array.shape[2]
        self.char_width = self.char_array.shape[3]
        self.char_array = self.char_array.transpose(0, 1, 4, 2, 3)  # CHW format

        self.crop_size = crop_size
        self.output_height_chars = crop_size
        self.output_width_chars = crop_size

        self.chw_image_shape = (
            3,
            self.output_height_chars * self.char_height,
            self.output_width_chars * self.char_width,
        )

        # Replace text observation with image
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=self.chw_image_shape, dtype=np.uint8
        )
        self.render_mode = 'rgb_array'

    def _render_text_to_image(self, obs):
        """Tile pre-rendered glyphs for the cropped tty grid into a CHW image."""
        chars = obs["tty_chars"]
        colors = obs["tty_colors"]
        center_y, center_x = obs["tty_cursor"]

        # Crop window centered on the cursor. Offsets can be negative near the
        # screen edge; presumably the tiling helper handles out-of-range cells.
        offset_h = center_y - self.crop_size // 2
        offset_w = center_x - self.crop_size // 2

        out_image = np.zeros(self.chw_image_shape, dtype=np.uint8)

        # Fast rendering using pre-cached character images (module-level helper)
        _tile_characters_to_image(
            out_image, chars, colors,
            self.output_height_chars, self.output_width_chars,
            self.char_array, offset_h, offset_w
        )

        return out_image

    def reset(self, **kwargs):
        """Reset the env and return the rendered first observation.

        Keyword arguments (e.g. ``seed``) are forwarded to the wrapped env's
        ``reset``. Previously they were rejected with TypeError.
        """
        obs = self.env.reset(**kwargs)
        self.obs = obs  # keep the raw dict observation
        return self._render_text_to_image(obs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.obs = obs  # keep the raw dict observation
        return self._render_text_to_image(obs), reward, done, info

Creating environment integrations

To integrate a new environment library, create a module with make() and env_creator() functions.

Basic structure

# pufferlib/environments/my_env/environment.py
import functools
import pufferlib
import pufferlib.emulation
import pufferlib.environments

def env_creator(name='default'):
    """Return a factory that builds environment ``name`` through ``make``.

    ``name`` is bound as the first positional argument of ``make``. Any other
    ``make`` keyword arguments (``buf``, ``seed``, ...) can still be supplied
    when the factory is called.
    """
    factory = functools.partial(make, name)
    return factory

def make(name, render_mode='rgb_array', buf=None, seed=0):
    """Create ``name`` from ``my_env_package`` and wrap it for PufferLib.

    Args:
        name: environment id passed to ``my_env_package.make``.
        render_mode: forwarded to ``my_env_package.make``.
        buf: optional pre-allocated buffers for ``GymnasiumPufferEnv``.
        seed: accepted for interface compatibility; not used by this template.

    Returns:
        A ``pufferlib.emulation.GymnasiumPufferEnv`` around the wrapped env.
    """
    # try_import hands back the imported package (or fails with a helpful
    # message when it is missing).
    package = pufferlib.environments.try_import('my_env_package')
    base_env = package.make(name, render_mode=render_mode)

    # Preprocess first, then convert to the PufferEnv interface.
    with_stats = pufferlib.EpisodeStats(base_env)
    return pufferlib.emulation.GymnasiumPufferEnv(env=with_stats, buf=buf)

Real example: Atari integration

From pufferlib/environments/atari/environment.py:11:
import functools
import numpy as np
import gymnasium as gym
import pufferlib
import pufferlib.emulation
import pufferlib.environments

def env_creator(name='breakout'):
    """Return a factory that builds the Atari game ``name`` through ``make``."""
    factory = functools.partial(make, name)
    return factory

def make(name='breakout', obs_type='grayscale', frameskip=4,
         full_action_space=False, framestack=1,
         repeat_action_probability=0.0, render_mode='rgb_array',
         buf=None, seed=0):
    """Atari creation function.

    Builds the pipeline
    ``AtariEnv -> ResizeObservation(downscale=2) -> [FrameStack] ->
    AtariPostprocessor -> EpisodeStats -> GymnasiumPufferEnv``.
    FrameStack is applied only when ``framestack > 1``. ``seed`` is accepted
    for interface compatibility but not used here.
    """
    # Fails with a helpful message if ale_py is not installed.
    pufferlib.environments.try_import('ale_py', 'AtariEnv')
    from ale_py import AtariEnv

    ale_kwargs = dict(
        obs_type=obs_type,
        frameskip=frameskip,
        repeat_action_probability=repeat_action_probability,
        full_action_space=full_action_space,
        render_mode=render_mode,
    )
    env = AtariEnv(name, **ale_kwargs)

    # Cheap strided downscaling instead of an interpolating resize.
    env = pufferlib.ResizeObservation(env, downscale=2)

    # NOTE(review): gymnasium's FrameStack appears to stack along a new leading
    # axis, giving (framestack, H, W). AtariPostprocessor treats any 3-D
    # observation as HWC and transposes it, so confirm the resulting axis order
    # is intended. FrameStack may also be named FrameStackObservation in newer
    # Gymnasium releases.
    if framestack > 1:
        env = gym.wrappers.FrameStack(env, framestack)

    env = pufferlib.EpisodeStats(AtariPostprocessor(env))
    return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf)

class AtariPostprocessor(gym.Wrapper):
    """Convert observations to channel-first layout and drop info dicts.

    2-D observations ``(H, W)`` gain a leading channel axis, giving
    ``(1, H, W)``. 3-D observations are assumed to be HWC and are transposed
    to CHW. The wrapped env's info dict is replaced by ``{}`` on both reset
    and step.
    """
    def __init__(self, env):
        super().__init__(env)
        shape = env.observation_space.shape
        if len(shape) < 3:
            shape = (1, *shape)
        else:
            shape = (shape[2], shape[0], shape[1])  # HWC -> CHW

        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=shape, dtype=env.observation_space.dtype
        )

    def unsqueeze_transpose(self, obs):
        """Return ``obs`` in the channel-first layout described above."""
        if len(obs.shape) == 3:
            return np.transpose(obs, (2, 0, 1))
        return np.expand_dims(obs, 0)

    def reset(self, seed=None, options=None):
        # Forward ``options`` too; previously it was silently dropped.
        obs, _ = self.env.reset(seed=seed, options=options)
        return self.unsqueeze_transpose(obs), {}

    def step(self, action):
        obs, reward, terminal, truncated, _ = self.env.step(action)
        return self.unsqueeze_transpose(obs), reward, terminal, truncated, {}

Real example: PettingZoo integration

From pufferlib/environments/butterfly/environment.py:8:
import functools
from pettingzoo.utils.conversions import aec_to_parallel_wrapper
import pufferlib.emulation
import pufferlib.environments

def env_creator(name='cooperative_pong_v5'):
    """Return a factory that builds the butterfly env ``name`` through ``make``."""
    factory = functools.partial(make, name)
    return factory

def make(name, buf=None):
    """Create a PettingZoo butterfly env as a parallel PufferEnv.

    Args:
        name: 'cooperative_pong_v5' or 'knights_archers_zombies_v10'.
        buf: optional pre-allocated buffers for ``PettingZooPufferEnv``.

    Raises:
        ValueError: for any other ``name``.
    """
    # Fails with a helpful message if PettingZoo's butterfly suite is missing.
    pufferlib.environments.try_import('pettingzoo.butterfly', 'butterfly')

    # Import lazily so only the requested game module is loaded.
    if name == 'cooperative_pong_v5':
        from pettingzoo.butterfly import cooperative_pong_v5 as module
    elif name == 'knights_archers_zombies_v10':
        from pettingzoo.butterfly import knights_archers_zombies_v10 as module
    else:
        raise ValueError(f'Unknown environment: {name}')

    # raw_env follows the AEC API; PufferLib needs the parallel API.
    parallel_env = aec_to_parallel_wrapper(module.raw_env())
    return pufferlib.emulation.PettingZooPufferEnv(env=parallel_env, buf=buf)

Implementing native PufferEnv

For maximum performance, implement environments directly in the PufferEnv interface. This avoids emulation overhead and allows full control over vectorization.

PufferEnv API

From pufferlib/pufferlib.py:45:
import numpy as np
import gymnasium
import pufferlib
import pufferlib.spaces

class PufferEnv:
    """Base class for native PufferLib environments.

    Subclasses must set ``single_observation_space``, ``single_action_space``
    and ``num_agents`` *before* calling ``super().__init__(buf)``. This
    initializer validates those attributes, sets up buffers via
    ``pufferlib.set_buffers`` and builds the joint (all-agent) spaces plus
    ``agent_ids``.
    """
    def __init__(self, buf=None):
        # Presence checks, in a fixed order so the first missing one is reported.
        for required in ('single_observation_space', 'single_action_space', 'num_agents'):
            if not hasattr(self, required):
                raise pufferlib.APIUsageError(f'Must define {required}')

        if self.num_agents < 1:
            raise pufferlib.APIUsageError('num_agents must be >= 1')

        # Native envs only accept Box observation spaces.
        if not isinstance(self.single_observation_space, pufferlib.spaces.Box):
            raise pufferlib.APIUsageError('observation_space must be a Box')

        pufferlib.set_buffers(self, buf)

        agents = self.num_agents
        self.action_space = pufferlib.spaces.joint_space(self.single_action_space, agents)
        self.observation_space = pufferlib.spaces.joint_space(self.single_observation_space, agents)
        self.agent_ids = np.arange(agents)

    @property
    def emulated(self):
        """Always False: native envs run without the emulation layer."""
        return False

    @property
    def done(self):
        """Always False: native envs are expected to reset themselves."""
        return False

    def reset(self, seed=None):
        """Return ``(observations, infos)``; subclasses must override."""
        raise NotImplementedError

    def step(self, actions):
        """Return ``(obs, rewards, terminals, truncations, infos)``; subclasses must override."""
        raise NotImplementedError

    def close(self):
        """Release resources; subclasses must override."""
        raise NotImplementedError

Example: Simple native environment

From examples/puffer_env.py:
import gymnasium
import numpy as np
import pufferlib

class SamplePufferEnv(pufferlib.PufferEnv):
    """Toy native PufferEnv with two agents and random dynamics.

    Observations are uniform noise in [-1, 1] and rewards are standard-normal
    noise. Episodes never terminate or truncate. Results are written in place
    into the buffers set up by ``PufferEnv.__init__``, and those same arrays
    are returned.
    """
    def __init__(self, buf=None, seed=0):
        # PufferEnv.__init__ validates these attributes and builds the joint
        # spaces from them, so they must exist before it is called.
        self.single_observation_space = gymnasium.spaces.Box(
            low=-1, high=1, shape=(4,), dtype=np.float32
        )
        self.single_action_space = gymnasium.spaces.Discrete(2)
        self.num_agents = 2

        super().__init__(buf)
        self.rng = np.random.RandomState(seed)

    def _fill_random_observations(self):
        """Overwrite the observation buffer in place with U(-1, 1) noise."""
        buffer = self.observations
        buffer[:] = self.rng.uniform(-1, 1, buffer.shape)

    def reset(self, seed=None):
        """Optionally reseed, then return fresh random observations."""
        if seed is not None:
            self.rng = np.random.RandomState(seed)
        self._fill_random_observations()
        return self.observations, [{}]

    def step(self, actions):
        """Ignore ``actions`` and emit random observations and rewards."""
        self._fill_random_observations()
        self.rewards[:] = self.rng.randn(self.num_agents)
        self.terminals[:] = False
        self.truncations[:] = False
        per_agent_infos = [dict(custom_metric=agent) for agent in range(self.num_agents)]
        return self.observations, self.rewards, self.terminals, self.truncations, per_agent_infos

    def close(self):
        """Nothing to release."""
Usage:
env = SamplePufferEnv()
try:
    obs, info = env.reset()

    for _ in range(100):
        actions = env.action_space.sample()
        obs, rewards, terminals, truncations, info = env.step(actions)
        print(f"Observations: {obs}")
        print(f"Rewards: {rewards}")
finally:
    # Always release environment resources, even if a step raises.
    env.close()

Best practices

Use try_import for dependencies

Always use pufferlib.environments.try_import() to check for environment dependencies:
import pufferlib.environments

# Check the dependency up front: this gives a helpful error message if
# ale_py is not installed. The second argument appears to be a label for that
# message (confirm against try_import's signature).
pufferlib.environments.try_import('ale_py', 'AtariEnv')

Provide env_creator function

Always provide an env_creator() function for use with vectorization:
import functools

def env_creator(name='default'):
    """Return a factory that builds environment ``name`` through ``make``.

    Vectorization code can call the factory later, optionally with extra
    ``make`` keyword arguments.
    """
    factory = functools.partial(make, name)
    return factory

Use episode statistics

Add EpisodeStats wrapper to track episode returns and lengths:
# Wrap before converting to a PufferEnv. When an episode ends, EpisodeStats
# reports the summed episode return and the episode length in the info dict.
env = pufferlib.EpisodeStats(env)

Document your wrapper

Include docstrings explaining what your wrapper does and any special requirements:
def make(name, buf=None):
    """Create MyEnv environment.

    Args:
        name: Environment name or task
        buf: Optional pre-allocated buffers

    Returns:
        PufferLib-wrapped environment

    Raises:
        NotImplementedError: until this template is filled in.
    """
    # A bare ``pass`` would silently return None despite the documented
    # return value, so fail loudly until the template is implemented.
    raise NotImplementedError('make() is a template; implement environment creation here')

Handle seed properly

Support seeding in your wrappers:
def reset(self, seed=None, options=None):
    """Reset the wrapped env, forwarding both ``seed`` and ``options``.

    Under the Gymnasium API, environments are seeded through
    ``reset(seed=...)``. There is no separate ``env.seed()`` method, so calling
    it raises AttributeError on modern envs.
    """
    return self.env.reset(seed=seed, options=options)

Next steps

Wrapper overview

Learn about the wrapper system

Ocean environments

Browse all Ocean environments

Build docs developers (and LLMs) love