PufferLib supports distributed training using PyTorch's DistributedDataParallel (DDP) to scale training across multiple GPUs. This enables near-linear speedup when training large models or many parallel environments.
Overview
Distributed training in PufferLib uses PyTorch’s torchrun utility and the NCCL backend for efficient multi-GPU communication. Each process runs its own environment instances and synchronizes gradients across all GPUs.
Distributed training requires NCCL (NVIDIA Collective Communications Library). Ensure your PyTorch installation includes CUDA support.
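Before launching, you can verify that your PyTorch build satisfies both requirements. This is a quick sanity check using standard PyTorch calls (not PufferLib APIs):

```python
import torch
import torch.distributed

# Is this a CUDA build with at least one visible GPU?
print(torch.cuda.is_available())

# Was PyTorch compiled with the NCCL backend?
print(torch.distributed.is_nccl_available())
```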
Quick start
Launch distributed training using torchrun:
torchrun --standalone --nnodes=1 --nproc-per-node=4 \
-m pufferlib.pufferl train your_env
This command launches training on 4 GPUs on a single node.
Distributed setup
Environment variables
PufferLib automatically detects distributed training through PyTorch environment variables:
LOCAL_RANK: GPU rank on the current node
WORLD_SIZE: Total number of processes
MASTER_ADDR: Address of the master node (default: localhost)
MASTER_PORT: Port for communication (default: 29500)
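A process launched by torchrun can inspect these variables at startup; the defaults below apply when running outside torchrun. A minimal sketch:

```python
import os

# torchrun sets these for every process it spawns; the fallbacks
# match the defaults listed above.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
master_addr = os.environ.get("MASTER_ADDR", "localhost")
master_port = os.environ.get("MASTER_PORT", "29500")

print(f"rank {local_rank}/{world_size} via {master_addr}:{master_port}")
```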
Initialization flow
From pufferlib/pufferl.py:914-939:
# Detect distributed training
if 'LOCAL_RANK' in os.environ:
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    master_addr = os.environ.get('MASTER_ADDR', 'localhost')
    master_port = os.environ.get('MASTER_PORT', '29500')
    local_rank = int(os.environ["LOCAL_RANK"])

    # Set device
    torch.cuda.set_device(local_rank)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)

# Initialize process group
if 'LOCAL_RANK' in os.environ:
    args['train']['device'] = torch.cuda.current_device()
    torch.distributed.init_process_group(backend='nccl', world_size=world_size)

    # Wrap policy with DDP
    policy = policy.to(local_rank)
    model = torch.nn.parallel.DistributedDataParallel(
        policy, device_ids=[local_rank], output_device=local_rank
    )
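The same flow can be exercised on a machine without GPUs. The sketch below uses the CPU `gloo` backend and a single process standing in for one rank; a real PufferLib launch uses `nccl` with one process per GPU:

```python
import os
import torch
import torch.distributed as dist

# Single-process stand-in for one rank; MASTER_ADDR/PORT mirror
# what torchrun would set.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# Any nn.Module can play the role of the policy here
policy = torch.nn.Linear(8, 2)
model = torch.nn.parallel.DistributedDataParallel(policy)

x = torch.randn(4, 8)
loss = model(x).sum()
loss.backward()  # DDP's hooks average gradients across ranks here

dist.destroy_process_group()
```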
Configuration
Single node, multiple GPUs
Train on all GPUs on a single machine:
torchrun --standalone --nnodes=1 --nproc-per-node=8 \
-m pufferlib.pufferl train atari
Multiple nodes
For multi-node training, specify the master node:
Launch on master node
torchrun --nnodes=2 --nproc-per-node=8 \
--master_addr=192.168.1.100 --master_port=29500 \
--node_rank=0 \
-m pufferlib.pufferl train your_env
Launch on worker nodes
torchrun --nnodes=2 --nproc-per-node=8 \
--master_addr=192.168.1.100 --master_port=29500 \
--node_rank=1 \
-m pufferlib.pufferl train your_env
Training configuration
Distributed training shares the same configuration as single-GPU training:
import pufferlib
import pufferlib.vector
# Environment and policy setup is identical
# Environment and policy setup is identical
vecenv = pufferlib.vector.make(
    env_creator,
    num_envs=128,
    backend='Multiprocessing'
)

# Training will automatically use DDP if launched with torchrun
config = {
    'device': 'cuda',
    'batch_size': 32768,
    'learning_rate': 3e-4,
    # ... other config
}

pufferl = pufferlib.PuffeRL(config, vecenv, policy)
Logging and checkpoints
Rank 0 only operations
Only rank 0 performs logging and checkpoint saving to avoid conflicts:
if torch.distributed.is_initialized():
    if torch.distributed.get_rank() != 0:
        # Non-master ranks skip logging
        return None

self.logger.log(logs, agent_steps)
return logs
From pufferlib/pufferl.py:505-510
Checkpoint saving
Checkpoints are saved only by rank 0:
def save_checkpoint(self):
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() != 0:
            # Only rank 0 saves
            return

    torch.save(self.uncompiled_policy.state_dict(), model_path)
From pufferlib/pufferl.py:524-527
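The same rank-0 guard works for any side effect you want to run exactly once per job. A hypothetical helper (not part of PufferLib) that also handles non-distributed runs:

```python
import torch

def is_main_process():
    # True on rank 0 of a distributed run, and in any
    # non-distributed (single-process) run.
    if not torch.distributed.is_initialized():
        return True
    return torch.distributed.get_rank() == 0

print(is_main_process())  # True when running without torchrun
```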
Gradient synchronization
Automatic synchronization
DDP automatically synchronizes gradients during backward pass. Each process computes gradients on its local batch, and DDP averages them across all processes.
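The averaging can be illustrated without any distributed setup. Below, two local copies of the same parameters stand in for two ranks, each seeing a different batch; the averaged gradient is what each rank ends up with after DDP's all-reduce:

```python
import torch

# Two "ranks" with identical initial parameters
w = torch.randn(3)
a = w.clone().requires_grad_(True)
b = w.clone().requires_grad_(True)

# Each rank computes gradients on its own local batch
(a * torch.tensor([1.0, 2.0, 3.0])).sum().backward()
(b * torch.tensor([5.0, 6.0, 7.0])).sum().backward()

# all_reduce(SUM) followed by division by world_size
avg_grad = (a.grad + b.grad) / 2
print(avg_grad)  # tensor([3., 4., 5.])
```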
Distributed aggregation
Metrics are aggregated across all processes using all_reduce:
def dist_sum(value, device):
    if not torch.distributed.is_initialized():
        return value
    tensor = torch.tensor(value, device=device)
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return tensor.item()

def dist_mean(value, device):
    if not torch.distributed.is_initialized():
        return value
    return dist_sum(value, device) / torch.distributed.get_world_size()
From pufferlib/pufferl.py:708-720
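Because both helpers fall through when no process group is initialized, the same metric code works unchanged in single-GPU runs. A self-contained check of that fallback path:

```python
import torch

def dist_sum(value, device):
    if not torch.distributed.is_initialized():
        return value
    tensor = torch.tensor(value, device=device)
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return tensor.item()

def dist_mean(value, device):
    if not torch.distributed.is_initialized():
        return value
    return dist_sum(value, device) / torch.distributed.get_world_size()

# In a non-distributed run, the input passes straight through
print(dist_sum(3.5, "cpu"))   # 3.5
print(dist_mean(3.5, "cpu"))  # 3.5
```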
Performance tips
For optimal performance:
- Use 1 process per GPU
- Ensure batch size is divisible by world size
- Use NCCL backend for GPU training
- Place environments on CPU to maximize GPU utilization
GPU utilization monitoring
Note that GPU utilization monitoring is disabled in distributed mode to avoid NVML conflicts:
# Inside the monitoring loop
if torch.cuda.is_available():
    if torch.distributed.is_initialized():
        time.sleep(self.delay)
        continue

    self.gpu_util.append(torch.cuda.utilization())
From pufferlib/pufferl.py:786-790
Scaling batch size
When scaling to multiple GPUs, increase the total batch size proportionally:
# Single GPU
config = {'batch_size': 32768}
# 4 GPUs - scale batch size
config = {'batch_size': 131072} # 32768 * 4
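Rather than hard-coding the product, you can derive the global batch size from torchrun's `WORLD_SIZE` so the per-GPU batch stays constant as you add GPUs. A hypothetical helper (`scaled_batch_size` is not a PufferLib API):

```python
import os

PER_GPU_BATCH = 32768

def scaled_batch_size():
    # WORLD_SIZE is set by torchrun; defaults to 1 outside it
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return PER_GPU_BATCH * world_size

config = {"batch_size": scaled_batch_size()}
```

This also keeps the global batch divisible by the world size, which the performance tips above recommend.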
Troubleshooting
NCCL timeout errors
Increase the timeout if experiencing slow environments:
export NCCL_TIMEOUT=1800 # 30 minutes
torchrun --standalone --nnodes=1 --nproc-per-node=4 \
-m pufferlib.pufferl train your_env
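Alternatively, PyTorch lets you raise the collective timeout when the process group is created, via the `timeout` argument to `init_process_group`. Shown here with the CPU `gloo` backend so it runs anywhere; on GPUs you would pass `backend="nccl"`:

```python
import os
from datetime import timedelta
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29502")

# Collectives will wait up to 30 minutes before raising
dist.init_process_group(
    backend="gloo", rank=0, world_size=1,
    timeout=timedelta(seconds=1800),
)
assert dist.is_initialized()
dist.destroy_process_group()
```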
Device placement errors
Ensure all tensors are on the correct device:
# Correct device assignment
args['train']['device'] = torch.cuda.current_device()
policy = policy.to(local_rank)
Out of memory
Reduce per-GPU batch size or enable CPU offloading:
config = {
    'batch_size': 16384,  # Smaller per-GPU batch
    'cpu_offload': True,  # Offload observations to CPU
}
Example: 4-GPU training
import pufferlib
import pufferlib.vector
def make_env():
    # Your environment creation
    return env

# This code runs on all ranks
vecenv = pufferlib.vector.make(
    make_env,
    num_envs=128,
    num_workers=16,
    backend='Multiprocessing'
)

policy = YourPolicy(vecenv.driver_env)

config = {
    'device': 'cuda',  # Will be set automatically in DDP
    'batch_size': 131072,  # 32768 * 4 GPUs
    'learning_rate': 3e-4,
    'total_timesteps': 1_000_000_000,
}
# Launch with:
# torchrun --standalone --nnodes=1 --nproc-per-node=4 train.py
pufferlib.train('your_env', args={'train': config}, vecenv=vecenv, policy=policy)
Monitor training with wandb or neptune. Only rank 0 will log metrics, preventing duplicate entries.