PufferLib includes CUDA extensions for performance-critical operations. The main CUDA kernel implements the advantage calculation used in PPO training.
Overview
PufferLib’s CUDA extensions provide:
- CUDA advantage kernel: Accelerated GAE-Lambda and V-trace computation
- Automatic fallback: CPU implementation when CUDA is unavailable
- PyTorch integration: Registered as custom PyTorch operators
CUDA extensions are automatically built during installation if CUDA is detected.
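Once built, the compiled extension is importable as pufferlib._C; importing it registers the advantage operator under torch.ops.pufferlib. A minimal check that the extension loaded (illustrative only):
try:
    from pufferlib import _C  # compiled C/CUDA extension; loading it registers torch.ops.pufferlib
    print("C/CUDA advantage kernel available")
except ImportError:
    print("Compiled extension not built; see Troubleshooting below")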
Build requirements
System requirements
CUDA toolkit
Install CUDA development tools (not just the runtime):
# Ubuntu/Debian
sudo apt install nvidia-cuda-toolkit
# Verify installation
nvcc --version
PyTorch with CUDA
Install PyTorch with CUDA support:
pip install torch --index-url https://download.pytorch.org/whl/cu121
C++ compiler
Ensure a compatible C++ compiler is available:
# Ubuntu/Debian
sudo apt install build-essential
# Verify
g++ --version
Build detection
From setup.py:25-27:
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
# Build CUDA extension if torch can find CUDA or HIP/ROCM
BUID_CUDA_EXT = bool(CUDA_HOME or ROCM_HOME)
Installation
Standard installation
CUDA extensions build automatically if CUDA is detected:
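# Standard install; the extension compiles automatically when CUDA_HOME or ROCM_HOME is found
pip install pufferlib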
Manual build
Build from source with custom options:
# Clone repository
git clone https://github.com/PufferAI/PufferLib.git
cd PufferLib
# Build with CUDA
python setup.py build_ext --inplace
# Or install
pip install -e .
Build without isolation
For custom PyTorch builds:
pip install --no-build-isolation -e .
If you have a non-default PyTorch installation, you may need --no-build-isolation to avoid build errors (see pufferlib/pufferl.py:34-37).
Extension implementation
CUDA kernel structure
From pufferlib/extensions/cuda/pufferlib.cu:7-20:
__host__ __device__ void puff_advantage_row_cuda(
    float* values, float* rewards, float* dones,
    float* importance, float* advantages,
    float gamma, float lambda,
    float rho_clip, float c_clip, int horizon
) {
    float lastpufferlam = 0;
    for (int t = horizon-2; t >= 0; t--) {
        int t_next = t + 1;
        float nextnonterminal = 1.0 - dones[t_next];
        float rho_t = fminf(importance[t], rho_clip);
        float c_t = fminf(importance[t], c_clip);
        float delta = rho_t*(rewards[t_next] + gamma*values[t_next]*nextnonterminal - values[t]);
        lastpufferlam = delta + gamma*lambda*c_t*lastpufferlam*nextnonterminal;
        advantages[t] = lastpufferlam;
    }
}
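This recursion is the standard backward GAE-lambda pass with V-trace-style clipped importance ratios: with importance weights of 1 (or rho_clip = c_clip = 1), it reduces to ordinary GAE. For reference, the same per-row recursion in NumPy (an illustrative sketch, not part of the library):
import numpy as np

def puff_advantage_row(values, rewards, dones, importance,
                       gamma, lam, rho_clip, c_clip):
    # Backward pass over one segment of length `horizon`
    horizon = len(values)
    advantages = np.zeros(horizon, dtype=np.float32)
    lastpufferlam = 0.0
    for t in range(horizon - 2, -1, -1):
        nextnonterminal = 1.0 - dones[t + 1]
        rho_t = min(importance[t], rho_clip)  # clipped ratio for the TD error
        c_t = min(importance[t], c_clip)      # clipped ratio for the trace
        delta = rho_t * (rewards[t + 1] + gamma * values[t + 1] * nextnonterminal - values[t])
        lastpufferlam = delta + gamma * lam * c_t * lastpufferlam * nextnonterminal
        advantages[t] = lastpufferlam
    return advantages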
Kernel launch
From pufferlib/extensions/cuda/pufferlib.cu:41-51:
__global__ void puff_advantage_kernel(
    float* values, float* rewards,
    float* dones, float* importance, float* advantages,
    float gamma, float lambda, float rho_clip, float c_clip,
    int num_steps, int horizon
) {
    int row = blockIdx.x*blockDim.x + threadIdx.x;
    if (row >= num_steps) {
        return;
    }
    int offset = row*horizon;
    puff_advantage_row_cuda(values + offset, rewards + offset, dones + offset,
        importance + offset, advantages + offset, gamma, lambda, rho_clip, c_clip, horizon);
}
Kernel configuration
From pufferlib/extensions/cuda/pufferlib.cu:53-76:
void compute_puff_advantage_cuda(
    torch::Tensor values, torch::Tensor rewards,
    torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
    double gamma, double lambda, double rho_clip, double c_clip
) {
    int num_steps = values.size(0);
    int horizon = values.size(1);
    vtrace_check_cuda(values, rewards, dones, importance, advantages, num_steps, horizon);
    TORCH_CHECK(values.is_cuda(), "All tensors must be on GPU");
    int threads_per_block = 256;
    int blocks = (num_steps + threads_per_block - 1) / threads_per_block;
    puff_advantage_kernel<<<blocks, threads_per_block>>>(
        values.data_ptr<float>(),
        rewards.data_ptr<float>(),
        dones.data_ptr<float>(),
        importance.data_ptr<float>(),
        advantages.data_ptr<float>(),
        gamma, lambda, rho_clip, c_clip, num_steps, horizon
    );
    // Error checking
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        throw std::runtime_error(cudaGetErrorString(err));
    }
}
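Each thread handles one row (one segment of length horizon). For example, with num_steps = 4096 segments the launch uses (4096 + 255) / 256 = 16 blocks of 256 threads.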
CPU fallback
CPU implementation
From pufferlib/extensions/pufferlib.cpp:28-41:
void puff_advantage_row(
    float* values, float* rewards, float* dones,
    float* importance, float* advantages,
    float gamma, float lambda,
    float rho_clip, float c_clip, int horizon
) {
    float lastpufferlam = 0;
    for (int t = horizon-2; t >= 0; t--) {
        int t_next = t + 1;
        float nextnonterminal = 1.0 - dones[t_next];
        float rho_t = fminf(importance[t], rho_clip);
        float c_t = fminf(importance[t], c_clip);
        float delta = rho_t*(rewards[t_next] + gamma*values[t_next]*nextnonterminal - values[t]);
        lastpufferlam = delta + gamma*lambda*c_t*lastpufferlam*nextnonterminal;
        advantages[t] = lastpufferlam;
    }
}
Automatic dispatch
From pufferlib/pufferl.py:660-680:
def compute_puff_advantage(
    values, rewards, terminals,
    ratio, advantages, gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip
):
    '''CUDA kernel for puffer advantage with automatic CPU fallback.'''
    device = values.device
    if not ADVANTAGE_CUDA:
        # Move to CPU for computation
        values = values.cpu()
        rewards = rewards.cpu()
        terminals = terminals.cpu()
        ratio = ratio.cpu()
        advantages = advantages.cpu()
    torch.ops.pufferlib.compute_puff_advantage(
        values, rewards, terminals,
        ratio, advantages, gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip
    )
    if not ADVANTAGE_CUDA:
        return advantages.to(device)
    return advantages
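A usage sketch of the wrapper (shapes and hyperparameters below are illustrative; the tensors are laid out as [num_steps, horizon] float32, and advantages is filled in place):
import torch
import pufferlib.pufferl as pufferl

device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_steps, horizon = 512, 64

values = torch.randn(num_steps, horizon, device=device)
rewards = torch.randn(num_steps, horizon, device=device)
terminals = torch.zeros(num_steps, horizon, device=device)
ratio = torch.ones(num_steps, horizon, device=device)  # importance ratios; 1.0 = on-policy
advantages = torch.zeros(num_steps, horizon, device=device)

advantages = pufferl.compute_puff_advantage(
    values, rewards, terminals, ratio, advantages,
    0.99, 0.95, 1.0, 1.0,  # gamma, gae_lambda, vtrace_rho_clip, vtrace_c_clip
)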
Build configuration
Extension sources
From setup.py:239-258:
torch_sources = [
    "pufferlib/extensions/pufferlib.cpp",
]
if BUID_CUDA_EXT:
    extension = CUDAExtension
    torch_sources.append("pufferlib/extensions/cuda/pufferlib.cu")
else:
    extension = CppExtension
torch_extensions = [
    extension(
        "pufferlib._C",
        torch_sources,
        extra_compile_args = {
            "cxx": cxx_args,
            "nvcc": nvcc_args,
        }
    ),
]
Compile flags
From setup.py:83-120:
cxx_args = [
    '-fdiagnostics-color=always',
]
nvcc_args = []
if DEBUG:
    extra_compile_args += [
        '-O0',
        '-g',
        '-fsanitize=address,undefined,bounds,pointer-overflow,leak',
        '-fno-omit-frame-pointer',
    ]
    nvcc_args += [
        '-O0',
        '-g',
    ]
else:
    extra_compile_args += [
        '-O2',
        '-flto',
    ]
    cxx_args += [
        '-O3',
    ]
    nvcc_args += [
        '-O3',
    ]
PyTorch operator registration
Operator definition
From pufferlib/extensions/pufferlib.cpp:87-89:
TORCH_LIBRARY(pufferlib, m) {
    m.def("compute_puff_advantage(Tensor(a!) values, Tensor(b!) rewards, Tensor(c!) dones, Tensor(d!) importance, Tensor(e!) advantages, float gamma, float lambda, float rho_clip, float c_clip) -> ()");
}
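The (a!) through (e!) annotations mark every tensor argument as mutable, and the op returns nothing: advantages is written in place, which is why the Python wrapper above simply returns the tensor it was given.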
CPU implementation
From pufferlib/extensions/pufferlib.cpp:91-93:
TORCH_LIBRARY_IMPL(pufferlib, CPU, m) {
    m.impl("compute_puff_advantage", &compute_puff_advantage_cpu);
}
CUDA implementation
From pufferlib/extensions/cuda/pufferlib.cu:84-86:
TORCH_LIBRARY_IMPL(pufferlib, CUDA, m) {
    m.impl("compute_puff_advantage", &compute_puff_advantage_cuda);
}
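With both backends registered, PyTorch's dispatcher selects the CPU or CUDA implementation based on the device of the input tensors, so the same torch.ops.pufferlib.compute_puff_advantage call works in either case.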
Troubleshooting
Import errors
If the CUDA extension fails to import, PufferLib raises the following error. From pufferlib/pufferl.py:34-37:
try:
    from pufferlib import _C
except ImportError:
    raise ImportError(
        'Failed to import C/CUDA advantage kernel. '
        'If you have non-default PyTorch, try installing with --no-build-isolation'
    )
Solution: Rebuild with --no-build-isolation:
pip install --no-build-isolation --force-reinstall pufferlib
CUDA not detected
Check CUDA detection:
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
print(f"CUDA_HOME: {CUDA_HOME}")
print(f"ROCM_HOME: {ROCM_HOME}")
Solution: Set CUDA_HOME environment variable:
export CUDA_HOME=/usr/local/cuda
pip install --force-reinstall pufferlib
Compilation errors
For debug builds with verbose output:
DEBUG=1 python setup.py build_ext --inplace --force
From setup.py:1-3
Version mismatches
Ensure CUDA toolkit version matches PyTorch:
# Check PyTorch CUDA version
python -c "import torch; print(torch.version.cuda)"
# Check CUDA toolkit version
nvcc --version
A major-version mismatch between PyTorch's CUDA build and the system CUDA toolkit can cause compilation or runtime errors.
Verifying kernel usage
Verify that the CUDA kernel is being used and benchmark the call:
import torch
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
ADVANTAGE_CUDA = bool(CUDA_HOME or ROCM_HOME)
print(f"Using CUDA kernels: {ADVANTAGE_CUDA}")
# Benchmark
import time
import pufferlib.pufferl as pufferl
values = torch.randn(1024, 128).cuda()
rewards = torch.randn(1024, 128).cuda()
terminals = torch.zeros(1024, 128).cuda()
ratio = torch.ones(1024, 128).cuda()
advantages = torch.zeros(1024, 128).cuda()
start = time.time()
for _ in range(100):
    pufferl.compute_puff_advantage(
        values, rewards, terminals, ratio, advantages,
        0.99, 0.95, 1.0, 1.0
    )
torch.cuda.synchronize()
print(f"Time: {(time.time() - start) * 10:.2f}ms per call")
ROCm/HIP support
PufferLib supports AMD GPUs through ROCm:
# Install PyTorch with ROCm
pip install torch --index-url https://download.pytorch.org/whl/rocm5.7
# Build will automatically detect ROCM_HOME
pip install pufferlib
ROCm builds use the same detection logic shown above (setup.py:25-27), with ROCM_HOME filling the role of CUDA_HOME.