Wrapper types
There are three main types of wrappers you might create:
- Preprocessing wrappers - Modify observations or rewards (Gymnasium/PettingZoo API)
- Integration wrappers - Add new environment libraries to PufferLib
- Native PufferEnv - Implement environments directly in the PufferEnv interface
Preprocessing wrappers
Preprocessing wrappers modify the environment before conversion to PufferEnv. They follow the standard Gymnasium or PettingZoo wrapper pattern.

Gymnasium preprocessing wrapper
Here’s a simple wrapper that normalizes observations:

import gymnasium
import numpy as np
class NormalizeObservation(gymnasium.Wrapper):
    """Rescale raw observations linearly into the [-1, 1] interval."""

    def __init__(self, env):
        super().__init__(env)
        source_space = env.observation_space
        # Advertise the rescaled space to downstream consumers
        self.observation_space = gymnasium.spaces.Box(
            low=-1.0,
            high=1.0,
            shape=source_space.shape,
            dtype=np.float32,
        )
        self.obs_low = source_space.low
        self.obs_high = source_space.high

    def normalize(self, obs):
        """Map obs from [obs_low, obs_high] onto [-1, 1]."""
        span = self.obs_high - self.obs_low
        return 2 * (obs - self.obs_low) / span - 1

    def reset(self, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options=options)
        return self.normalize(obs), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self.normalize(obs), reward, terminated, truncated, info
import gymnasium
import pufferlib.emulation
# Build the base Gymnasium environment
env = gymnasium.make('CartPole-v1')
env = NormalizeObservation(env)  # Apply custom wrapper
# Convert to the PufferEnv interface used by PufferLib vectorization
env = pufferlib.emulation.GymnasiumPufferEnv(env=env)
Real example: ResizeObservation
PufferLib’s ResizeObservation wrapper efficiently downscales images. From pufferlib/pufferlib.py:114:
import gymnasium
import numpy as np
class ResizeObservation(gymnasium.Wrapper):
    """Downscale 2D image observations by a fixed integer factor.

    Uses plain NumPy strided slicing (keep every k-th pixel), which is
    much faster than interpolation-based resizing such as OpenCV's
    (roughly 50% less overhead on Atari).
    """

    def __init__(self, env, downscale=2):
        super().__init__(env)
        self.downscale = downscale
        height, width = env.observation_space.shape
        # Only exact integer downscaling is supported
        assert height % downscale == 0 and width % downscale == 0
        self.observation_space = gymnasium.spaces.Box(
            low=0,
            high=255,
            shape=(height // downscale, width // downscale),
            dtype=np.uint8,
        )

    def reset(self, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options=options)
        stride = self.downscale
        return obs[::stride, ::stride], info

    def step(self, action):
        obs, reward, terminal, truncated, info = self.env.step(action)
        stride = self.downscale
        return obs[::stride, ::stride], reward, terminal, truncated, info
Real example: EpisodeStats
Tracks cumulative episode statistics. From pufferlib/pufferlib.py:155:
import gymnasium
class EpisodeStats(gymnasium.Wrapper):
    """Store episodic returns and lengths in info dict"""

    def __init__(self, env):
        # NOTE(review): gymnasium.Wrapper.__init__ is not called here;
        # self.env is assigned directly and the spaces are mirrored by
        # hand. super().step() in step() still works because
        # gymnasium.Wrapper.step dispatches through self.env.
        self.env = env
        self.observation_space = env.observation_space
        self.action_space = env.action_space
        # Resets the wrapped env once at construction to initialize
        # self.info — note this consumes one episode start.
        self.reset()

    def reset(self, seed=None, options=None):
        # Fresh accumulators for the new episode. `options` is accepted
        # for API compatibility but not forwarded — TODO confirm intended.
        self.info = dict(episode_return=[], episode_length=0)
        return self.env.reset(seed=seed)

    def step(self, action):
        observation, reward, terminated, truncated, info = super().step(action)
        # Accumulate nested info dicts
        # (unroll_nested_dict is a pufferlib helper defined elsewhere that
        # flattens nested dicts into (key, value) pairs)
        for k, v in unroll_nested_dict(info):
            if k not in self.info:
                self.info[k] = []
            self.info[k].append(v)
        self.info['episode_return'].append(reward)
        self.info['episode_length'] += 1
        # On episode end, sum/aggregate stats
        info = {}
        if terminated or truncated:
            for k, v in self.info.items():
                try:
                    info[k] = sum(v)  # Sum numeric values
                except TypeError:
                    info[k] = v  # Keep non-numeric as-is
        return observation, reward, terminated, truncated, info
PettingZoo preprocessing wrapper
For multi-agent environments, use PettingZooWrapper as a base:
import pufferlib
class PettingZooWrapper:
    """Base wrapper for PettingZoo environments"""

    def __init__(self, env):
        self.env = env

    def __getattr__(self, name):
        # Delegate unknown attribute lookups to the wrapped env, but
        # refuse private names (PettingZoo's _cumulative_rewards is the
        # one sanctioned exception).
        is_private = name.startswith('_')
        if is_private and name != '_cumulative_rewards':
            raise AttributeError(f'accessing private attribute "{name}" is prohibited')
        return getattr(self.env, name)

    @property
    def unwrapped(self):
        return self.env.unwrapped

    def close(self):
        self.env.close()

    def render(self):
        return self.env.render()

    def reset(self, seed=None, options=None):
        # Some envs predate the `options` kwarg; retry without it.
        try:
            return self.env.reset(seed=seed, options=options)
        except TypeError:
            return self.env.reset(seed=seed)

    def step(self, action):
        return self.env.step(action)
Real example: MeanOverAgents
Averages info dicts across agents. From pufferlib/pufferlib.py:250:
import numpy as np
class MeanOverAgents(PettingZooWrapper):
    """Averages info values over all agents.

    Per-agent info dicts are merged key-wise; numeric values are replaced
    by their mean across agents, non-numeric values are dropped.
    """

    def _mean(self, infos):
        """Collapse {agent: {key: value}} into {key: mean across agents}."""
        list_infos = {}
        for agent, info in infos.items():
            for k, v in info.items():
                if k not in list_infos:
                    list_infos[k] = []
                list_infos[k].append(v)
        mean_infos = {}
        for k, v in list_infos.items():
            try:
                mean_infos[k] = np.mean(v)
            except (TypeError, ValueError):
                # Non-numeric values (e.g. strings) cannot be averaged;
                # drop them. Was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit — narrowed to the errors
                # np.mean actually raises on bad input.
                pass
        return mean_infos

    def reset(self, seed=None, options=None):
        observations, infos = super().reset(seed, options)
        return observations, self._mean(infos)

    def step(self, actions):
        observations, rewards, terminations, truncations, infos = super().step(actions)
        return observations, rewards, terminations, truncations, self._mean(infos)
Environment-specific wrappers
Some environments need custom preprocessing. Here are real examples from PufferLib.

MiniGrid wrapper
Removes the mission string from observations. From pufferlib/environments/minigrid/environment.py:27:
import gymnasium
class MiniGridWrapper:
    """Remove mission string from MiniGrid observations"""

    def __init__(self, env):
        self.env = env
        # Mirror the env's Dict space, minus the 'mission' text entry
        filtered = {
            key: space
            for key, space in env.observation_space.items()
            if key != 'mission'
        }
        self.observation_space = gymnasium.spaces.Dict(filtered)
        self.action_space = env.action_space
        self.close = env.close
        self.render = env.render
        self.render_mode = 'rgb_array'

    def reset(self, seed=None, options=None):
        self.tick = 0
        obs, info = self.env.reset(seed=seed)
        obs.pop('mission')  # Remove mission string
        return obs, info

    def step(self, action):
        obs, reward, done, truncated, info = self.env.step(action)
        obs.pop('mission')
        # Hard 100-step timeout
        self.tick += 1
        if self.tick == 100:
            done = True
        return obs, reward, done, truncated, info
NetHack wrapper
Renders text as images for visual observation. From pufferlib/environments/nethack/wrapper.py:136:
import gym
import numpy as np
from numba import njit
class RenderCharImagesWithNumpyWrapper(gym.Wrapper):
    """Render NetHack characters as RGB images"""

    def __init__(self, env, font_size=9, crop_size=12, rescale_font_size=(6, 6)):
        super().__init__(env)
        # Pre-render all characters in all colors
        # (_initialize_char_array is a module-level helper defined elsewhere)
        self.char_array = _initialize_char_array(font_size, rescale_font_size)
        self.char_height = self.char_array.shape[2]
        self.char_width = self.char_array.shape[3]
        self.char_array = self.char_array.transpose(0, 1, 4, 2, 3)  # CHW format
        self.crop_size = crop_size
        # Output is a square crop of crop_size x crop_size characters
        self.output_height_chars = crop_size
        self.output_width_chars = crop_size
        self.chw_image_shape = (
            3,
            self.output_height_chars * self.char_height,
            self.output_width_chars * self.char_width,
        )
        # Replace text observation with image
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=self.chw_image_shape, dtype=np.uint8
        )
        self.render_mode = 'rgb_array'

    def _render_text_to_image(self, obs):
        # Rasterize the tty character grid into a CHW uint8 image
        chars = obs["tty_chars"]
        colors = obs["tty_colors"]
        center_y, center_x = obs["tty_cursor"]
        # Crop around player
        offset_h = center_y - self.crop_size // 2
        offset_w = center_x - self.crop_size // 2
        out_image = np.zeros(self.chw_image_shape, dtype=np.uint8)
        # Fast rendering using pre-cached character images
        # (_tile_characters_to_image is a numba-jitted helper defined elsewhere)
        _tile_characters_to_image(
            out_image, chars, colors,
            self.output_height_chars, self.output_width_chars,
            self.char_array, offset_h, offset_w
        )
        return out_image

    def reset(self):
        # NOTE: old `gym` API — reset() returns obs only (no info dict)
        obs = self.env.reset()
        self.obs = obs  # keep raw obs around for inspection
        return self._render_text_to_image(obs)

    def step(self, action):
        # NOTE: old `gym` API — 4-tuple (obs, reward, done, info)
        obs, reward, done, info = self.env.step(action)
        self.obs = obs
        return self._render_text_to_image(obs), reward, done, info
Creating environment integrations
To integrate a new environment library, create a module with make() and env_creator() functions.
Basic structure
# pufferlib/environments/my_env/environment.py
import functools
import pufferlib
import pufferlib.emulation
import pufferlib.environments
def env_creator(name='default'):
    """Returns a partial function for creating environments"""
    return functools.partial(make, name)


def make(name, render_mode='rgb_array', buf=None, seed=0):
    """Create and wrap the environment"""
    # Import the base environment (helpful error message if missing)
    my_env = pufferlib.environments.try_import('my_env_package')
    # Build the env, add episode-stat tracking, then convert to PufferEnv
    base = my_env.make(name, render_mode=render_mode)
    wrapped = pufferlib.EpisodeStats(base)
    return pufferlib.emulation.GymnasiumPufferEnv(env=wrapped, buf=buf)
Real example: Atari integration
From pufferlib/environments/atari/environment.py:11:
import functools
import numpy as np
import gymnasium as gym
import pufferlib
import pufferlib.emulation
import pufferlib.environments
def env_creator(name='breakout'):
    """Return a factory binding `make` to the given Atari game name."""
    return functools.partial(make, name)


def make(name='breakout', obs_type='grayscale', frameskip=4,
         full_action_space=False, framestack=1,
         repeat_action_probability=0.0, render_mode='rgb_array',
         buf=None, seed=0):
    """Atari creation function"""
    # NOTE(review): `seed` is accepted but never used below — TODO confirm
    # Check if ALE is installed
    pufferlib.environments.try_import('ale_py', 'AtariEnv')
    from ale_py import AtariEnv
    env = AtariEnv(
        name,
        obs_type=obs_type,
        frameskip=frameskip,
        repeat_action_probability=repeat_action_probability,
        full_action_space=full_action_space,
        render_mode=render_mode
    )
    # Fast downscaling
    env = pufferlib.ResizeObservation(env, downscale=2)
    # Optional frame stacking
    if framestack > 1:
        env = gym.wrappers.FrameStack(env, framestack)
    # Atari-specific postprocessor
    env = AtariPostprocessor(env)
    env = pufferlib.EpisodeStats(env)
    return pufferlib.emulation.GymnasiumPufferEnv(env=env, buf=buf)
class AtariPostprocessor(gym.Wrapper):
    """Transpose observations to CHW format"""

    def __init__(self, env):
        super().__init__(env)
        src_shape = env.observation_space.shape
        if len(src_shape) < 3:
            # Grayscale (H, W): add a leading channel axis
            out_shape = (1, *src_shape)
        else:
            # Color (H, W, C) -> (C, H, W)
            out_shape = (src_shape[2], src_shape[0], src_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=out_shape, dtype=env.observation_space.dtype
        )

    def unsqueeze_transpose(self, obs):
        # 3D color frames are transposed; 2D grayscale gains a channel dim
        if len(obs.shape) == 3:
            return np.transpose(obs, (2, 0, 1))
        return np.expand_dims(obs, 0)

    def reset(self, seed=None, options=None):
        obs, _ = self.env.reset(seed=seed)
        return self.unsqueeze_transpose(obs), {}

    def step(self, action):
        obs, reward, terminal, truncated, _ = self.env.step(action)
        return self.unsqueeze_transpose(obs), reward, terminal, truncated, {}
Real example: PettingZoo integration
From pufferlib/environments/butterfly/environment.py:8:
import functools
from pettingzoo.utils.conversions import aec_to_parallel_wrapper
import pufferlib.emulation
import pufferlib.environments
def env_creator(name='cooperative_pong_v5'):
    """Return a factory that builds the named butterfly environment."""
    return functools.partial(make, name)


def make(name, buf=None):
    """Create a PettingZoo butterfly env wrapped as a PettingZooPufferEnv."""
    # Gives a helpful install hint if pettingzoo is missing
    pufferlib.environments.try_import('pettingzoo.butterfly', 'butterfly')
    if name == 'cooperative_pong_v5':
        from pettingzoo.butterfly import cooperative_pong_v5 as module
        env_cls = module.raw_env
    elif name == 'knights_archers_zombies_v10':
        from pettingzoo.butterfly import knights_archers_zombies_v10 as module
        env_cls = module.raw_env
    else:
        raise ValueError(f'Unknown environment: {name}')
    # AEC (turn-based) -> parallel API, then to the PufferEnv interface
    parallel_env = aec_to_parallel_wrapper(env_cls())
    return pufferlib.emulation.PettingZooPufferEnv(env=parallel_env, buf=buf)
Implementing native PufferEnv
For maximum performance, implement environments directly in the PufferEnv interface. This avoids emulation overhead and allows full control over vectorization.

PufferEnv API
From pufferlib/pufferlib.py:45:
import numpy as np
import gymnasium
import pufferlib
import pufferlib.spaces
class PufferEnv:
    """Base class for native PufferLib environments"""

    def __init__(self, buf=None):
        # Subclasses must assign single_observation_space,
        # single_action_space, and num_agents BEFORE calling
        # super().__init__(). Check them in that order.
        for required in ('single_observation_space',
                         'single_action_space', 'num_agents'):
            if not hasattr(self, required):
                raise pufferlib.APIUsageError(f'Must define {required}')
        if self.num_agents < 1:
            raise pufferlib.APIUsageError('num_agents must be >= 1')
        # Validate spaces
        if not isinstance(self.single_observation_space, pufferlib.spaces.Box):
            raise pufferlib.APIUsageError('observation_space must be a Box')
        # Shared observation/reward/terminal buffers (allocated or passed in)
        pufferlib.set_buffers(self, buf)
        # Batched (joint) spaces spanning all agents
        self.action_space = pufferlib.spaces.joint_space(
            self.single_action_space, self.num_agents)
        self.observation_space = pufferlib.spaces.joint_space(
            self.single_observation_space, self.num_agents)
        self.agent_ids = np.arange(self.num_agents)

    @property
    def emulated(self):
        """Native envs do not use emulation"""
        return False

    @property
    def done(self):
        """Native envs handle resets internally"""
        return False

    def reset(self, seed=None):
        """Reset environment and return (observations, infos)"""
        raise NotImplementedError

    def step(self, actions):
        """Step environment and return (obs, rewards, terminals, truncations, infos)"""
        raise NotImplementedError

    def close(self):
        """Clean up resources"""
        raise NotImplementedError
Example: Simple native environment
From examples/puffer_env.py:
import gymnasium
import numpy as np
import pufferlib
class SamplePufferEnv(pufferlib.PufferEnv):
    """A simple native PufferEnv implementation"""

    def __init__(self, buf=None, seed=0):
        # Spaces and agent count must exist BEFORE super().__init__()
        self.single_observation_space = gymnasium.spaces.Box(
            low=-1, high=1, shape=(4,), dtype=np.float32)
        self.single_action_space = gymnasium.spaces.Discrete(2)
        self.num_agents = 2
        super().__init__(buf)  # allocates/binds the shared buffers
        self.rng = np.random.RandomState(seed)

    def reset(self, seed=None):
        if seed is not None:
            self.rng = np.random.RandomState(seed)
        # Write directly into the shared observation buffer
        self.observations[:] = self.rng.uniform(-1, 1, self.observations.shape)
        return self.observations, [{}]

    def step(self, actions):
        # In-place buffer updates keep vectorization zero-copy
        self.observations[:] = self.rng.uniform(-1, 1, self.observations.shape)
        self.rewards[:] = self.rng.randn(self.num_agents)
        self.terminals[:] = False
        self.truncations[:] = False
        # One info dict per agent
        infos = [{'custom_metric': i} for i in range(self.num_agents)]
        return self.observations, self.rewards, self.terminals, self.truncations, infos

    def close(self):
        pass
# Quick smoke test of the sample environment
env = SamplePufferEnv()
obs, info = env.reset()
for _ in range(100):
    # Sample one joint action covering all agents
    actions = env.action_space.sample()
    obs, rewards, terminals, truncations, info = env.step(actions)
print(f"Observations: {obs}")
print(f"Rewards: {rewards}")
Best practices
Use try_import for dependencies
Always use pufferlib.environments.try_import() to check for environment dependencies:
import pufferlib.environments
# This gives a helpful error message if ale_py is not installed
pufferlib.environments.try_import('ale_py', 'AtariEnv')
Provide env_creator function
Always provide an env_creator() function for use with vectorization:
import functools
def env_creator(name='default'):
    # Bind the default task name; vectorizers call the returned partial
    return functools.partial(make, name)
Use episode statistics
Add the EpisodeStats wrapper to track episode returns and lengths:
env = pufferlib.EpisodeStats(env)
Document your wrapper
Include docstrings explaining what your wrapper does and any special requirements:

def make(name, buf=None):
"""Create MyEnv environment.
Args:
name: Environment name or task
buf: Optional pre-allocated buffers
Returns:
PufferLib-wrapped environment
"""
pass
Handle seed properly
Support seeding in your wrappers:

def reset(self, seed=None, options=None):
if seed is not None:
self.env.seed(seed)
return self.env.reset()
Next steps
Wrapper overview
Learn about the wrapper system
Ocean environments
Browse all Ocean environments