LeRobot policies support PyTorch acceleration techniques like torch.compile, mixed precision training, and hardware-specific optimizations. This guide shows you how to enable and configure these features for faster inference and training.
torch.compile
torch.compile is PyTorch’s JIT compiler that optimizes model execution graphs. LeRobot policies include built-in support for compilation.
Enable Compilation
Set compile_model=True in your policy configuration:
```python
from lerobot.policies.pi0 import Pi0Config, Pi0Policy

config = Pi0Config(
    compile_model=True,
    compile_mode="max-autotune",  # Optimization level
    # ... other config
)
policy = Pi0Policy(config)
```
See lerobot/policies/pi0/configuration_pi0.py for configuration options.
Compilation Modes
PyTorch offers different compilation modes:
- default: Balanced compilation with good speedup and reasonable compilation time
- reduce-overhead: Minimizes Python overhead; good for small models
- max-autotune: Maximum optimization; slower compilation but best performance (recommended for deployment)
```python
# Fast compilation, moderate speedup
config = Pi0Config(
    compile_model=True,
    compile_mode="default"
)

# Maximum performance (slower first run)
config = Pi0Config(
    compile_model=True,
    compile_mode="max-autotune"
)
```
What Gets Compiled
Different policies compile different components:
Pi0/Pi05 Policies:
```python
# From lerobot/policies/pi0/modeling_pi0.py
if config.compile_model:
    self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
    self.forward = torch.compile(self.forward, mode=config.compile_mode)
```
Diffusion Policy:
```python
# From lerobot/policies/diffusion/modeling_diffusion.py
if config.compile_model:
    self.unet = torch.compile(self.unet, mode=config.compile_mode)
```
SAC Policy:
```python
# From lerobot/policies/sac/modeling_sac.py
if self.config.use_torch_compile:
    self.critic_ensemble = torch.compile(self.critic_ensemble)
    self.critic_target = torch.compile(self.critic_target)
```
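The same pattern applies to any policy: compile only the performance-critical submodule and leave the rest in eager mode. A minimal sketch with a stand-in module (`TinyPolicy` is hypothetical, not a LeRobot class); `backend="eager"` traces the graph without generating native code so the sketch runs without a compiler toolchain, and can be dropped to use the default inductor backend:

```python
import torch
import torch.nn as nn

class TinyPolicy(nn.Module):
    """Stand-in policy with one hot submodule (for illustration only)."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

    def forward(self, obs):
        return self.backbone(obs)

policy = TinyPolicy().eval()
obs = torch.randn(2, 8)
with torch.no_grad():
    out_eager = policy(obs)

# Compile only the hot submodule, as the LeRobot policies above do
policy.backbone = torch.compile(policy.backbone, backend="eager")
with torch.no_grad():
    out_compiled = policy(obs)

print(torch.allclose(out_eager, out_compiled))  # True: compilation preserves outputs
```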
Compilation Workflow
First inference triggers compilation:
```python
import time

policy = Pi0Policy.from_pretrained("lerobot/pi0_model")
policy.config.compile_model = True
policy.config.compile_mode = "max-autotune"

# First call: slow (compilation happens)
print("First inference (with compilation)...")
start = time.time()
action = policy.predict_action_chunk(observation)
print(f"Time: {time.time() - start:.2f}s")  # e.g., 30s

# Subsequent calls: fast (using compiled graph)
print("Second inference (using compiled graph)...")
start = time.time()
action = policy.predict_action_chunk(observation)
print(f"Time: {time.time() - start:.2f}s")  # e.g., 0.1s
```
Avoiding Graph Breaks
Some operations cause graph breaks and recompilation. Avoid:
- Dynamic control flow based on tensor values
- In-place operations on inputs
- Python print statements in hot paths
```python
# Bad: Dynamic control flow
if action.mean() > 0.5:  # Graph break!
    action = action * 2

# Good: Use torch operations
action = torch.where(action.mean() > 0.5, action * 2, action)

# Bad: Print in hot path
print(f"Action: {action}")  # Graph break!

# Good: Use torch.compiler.disable
with torch.compiler.disable():
    print(f"Action: {action}")
```
Hardware Acceleration
CUDA (NVIDIA GPUs)
Use CUDA for NVIDIA GPUs:
```python
policy = policy.to("cuda")

# Enable TF32 for faster matmul on Ampere+ GPUs
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Enable cuDNN auto-tuner
torch.backends.cudnn.benchmark = True
```
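Recent PyTorch releases also expose TF32-style matmul behaviour through a single knob, `torch.set_float32_matmul_precision`; a short sketch (safe to call on any machine, with or without a GPU):

```python
import torch

# "high" allows TF32 (or comparable lower-precision) matmuls on supported GPUs;
# "highest" keeps full float32 precision
torch.set_float32_matmul_precision("high")
print(torch.get_float32_matmul_precision())  # high
```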
MPS (Apple Silicon)
Use MPS for Apple Silicon Macs:
```python
if torch.backends.mps.is_available():
    policy = policy.to("mps")

# Note: MPS doesn't support float64
# DeviceProcessorStep automatically converts float64 -> float32
```
See lerobot/processor/device_processor.py:108 for MPS dtype handling.
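A small device-selection helper (a sketch, not a LeRobot API; `pick_device` is a hypothetical name) keeps scripts portable across CUDA, MPS, and CPU backends:

```python
import torch

def pick_device() -> str:
    """Prefer CUDA, then MPS, then fall back to CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

device = pick_device()
print(device)
# policy = policy.to(device)
```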
CPU Optimization
Optimize CPU inference:
```python
import os

import torch

# Use all CPU cores
torch.set_num_threads(os.cpu_count())

# MKL-accelerated kernels are used automatically when PyTorch is built with MKL
print(torch.backends.mkl.is_available())

# Compile with max-autotune for CPU
config.compile_model = True
config.compile_mode = "max-autotune"
```
Mixed Precision
Use mixed precision (float16/bfloat16) for faster training and inference:
Automatic Mixed Precision (AMP)
For training:
```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass in mixed precision
    with autocast():
        loss = policy.compute_loss(batch)

    # Backward pass with gradient scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
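autocast also works on CPU with bfloat16, which makes its mechanics easy to inspect without a GPU. A minimal sketch:

```python
import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)

# Inside autocast, eligible ops (such as matmul) run in the lower-precision dtype
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = torch.mm(a, b)

print(out.dtype)  # torch.bfloat16
print(a.dtype)    # torch.float32: inputs are left untouched
```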
Manual Precision Control
Control dtype via DeviceProcessorStep:
```python
from lerobot.processor import DeviceProcessorStep

# Use bfloat16 (better for training)
device_processor = DeviceProcessorStep(
    device="cuda",
    float_dtype="bfloat16"
)

# Use float16 (better for inference)
device_processor = DeviceProcessorStep(
    device="cuda",
    float_dtype="float16"
)
```
See lerobot/processor/device_processor.py:34 for dtype options.
Precision Recommendations
- Training: Use bfloat16 for better numerical stability
- Inference: Use float16 for maximum speed
- CPU: Use float32 (half-precision kernels on CPU are limited and often slower)
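The memory side of this trade-off is easy to verify: both half-precision dtypes use half the bytes per element of float32. A quick check:

```python
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    t = torch.zeros(1024, dtype=dtype)
    # element_size() is bytes per element; halving it halves activation memory
    print(dtype, t.element_size(), "bytes/elem,", t.numel() * t.element_size(), "bytes total")
```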
Optimization Checklist
For Inference
```python
import torch

from lerobot.processor import DeviceProcessorStep

# 1. Enable torch.compile
config.compile_model = True
config.compile_mode = "max-autotune"

# 2. Use appropriate device
policy = policy.to("cuda")

# 3. Enable hardware optimizations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# 4. Use mixed precision
device_processor = DeviceProcessorStep(device="cuda", float_dtype="float16")

# 5. Disable gradient computation
torch.set_grad_enabled(False)

# 6. Use eval mode
policy.eval()
```
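Steps 5 and 6 can be exercised in isolation on a stand-in module (a sketch, not LeRobot code); `torch.inference_mode()` is a slightly stronger alternative to `torch.set_grad_enabled(False)` for pure inference:

```python
import torch
import torch.nn as nn

# Stand-in for a policy; eval() freezes dropout and batch-norm statistics
model = nn.Linear(8, 4).eval()

# inference_mode() disables autograd tracking entirely
with torch.inference_mode():
    out = model(torch.randn(2, 8))

print(out.requires_grad)  # False: no autograd graph was built
```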
For Training
```python
import torch
from torch.cuda.amp import autocast, GradScaler

# 1. Optionally enable torch.compile (slower first epoch)
config.compile_model = True
config.compile_mode = "default"  # Faster compilation

# 2. Use GPU
policy = policy.to("cuda")

# 3. Enable hardware optimizations
torch.backends.cuda.matmul.allow_tf32 = True

# 4. Use AMP for mixed precision
scaler = GradScaler()

# 5. Use train mode
policy.train()

# 6. Enable cuDNN auto-tuner (for fixed input sizes)
torch.backends.cudnn.benchmark = True
```
Benchmarking
Measure speedup from optimizations:
```python
import time

import torch

def benchmark_policy(policy, observation, num_iterations=100, warmup=10):
    """Benchmark policy inference speed."""
    # Warmup (also triggers compilation, so it is excluded from timing)
    for _ in range(warmup):
        _ = policy.predict_action_chunk(observation)

    # Benchmark
    torch.cuda.synchronize()  # Wait for pending GPU work
    start = time.time()
    for _ in range(num_iterations):
        _ = policy.predict_action_chunk(observation)
    torch.cuda.synchronize()
    elapsed = time.time() - start

    avg_time = elapsed / num_iterations
    fps = 1.0 / avg_time
    print(f"Average time: {avg_time*1000:.2f}ms")
    print(f"Throughput: {fps:.1f} FPS")
    return avg_time

# Compare configurations (load_policy is a placeholder for your own loading helper)
print("Baseline (float32, no compile):")
policy1 = load_policy(compile_model=False, dtype="float32")
time1 = benchmark_policy(policy1, obs)

print("\nWith torch.compile:")
policy2 = load_policy(compile_model=True, dtype="float32")
time2 = benchmark_policy(policy2, obs)

print("\nWith torch.compile + float16:")
policy3 = load_policy(compile_model=True, dtype="float16")
time3 = benchmark_policy(policy3, obs)

print(f"\nSpeedup from compile: {time1/time2:.2f}x")
print(f"Speedup from compile+fp16: {time1/time3:.2f}x")
```
Example output:
```text
Baseline (float32, no compile):
Average time: 45.3ms
Throughput: 22.1 FPS

With torch.compile:
Average time: 12.1ms
Throughput: 82.6 FPS

With torch.compile + float16:
Average time: 6.8ms
Throughput: 147.1 FPS

Speedup from compile: 3.74x
Speedup from compile+fp16: 6.66x
```
Troubleshooting
Issue: Compilation Takes Too Long
Solution: Use faster compile mode or disable for development:
```python
# For development: no compilation
config.compile_model = False

# For faster compilation: use "default" mode
config.compile_mode = "default"

# For deployment: use "max-autotune" and cache
config.compile_mode = "max-autotune"
```
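Compiled kernels can also be reused across process restarts via the inductor cache: recent PyTorch releases read the `TORCHINDUCTOR_CACHE_DIR` environment variable, so pointing it at a persistent path avoids paying the full compile cost on every deployment. A sketch (the path here is just an example):

```python
import os

# Set before the first torch.compile call so compiled kernels
# survive process restarts
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torchinductor_cache"
```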
Issue: Out of Memory with Mixed Precision
Solution: Reduce batch size or use gradient accumulation:
```python
# Reduce batch size
config.batch_size = config.batch_size // 2

# Or use gradient accumulation
accumulation_steps = 2
for i, batch in enumerate(dataloader):
    with autocast():
        loss = policy.compute_loss(batch) / accumulation_steps
    scaler.scale(loss).backward()

    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```
Issue: Numerical Instability with float16
Solution: Use bfloat16 or gradient scaling:
```python
# Option 1: Use bfloat16 (better numerical stability)
device_processor = DeviceProcessorStep(device="cuda", float_dtype="bfloat16")

# Option 2: Use GradScaler with float16
from torch.cuda.amp import GradScaler
scaler = GradScaler()
```
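The instability is easy to reproduce: float16 overflows to infinity above roughly 65504, while bfloat16 keeps float32's exponent range. A quick check:

```python
import torch

x = torch.tensor(70000.0)

# float16's max value is ~65504, so this overflows to inf
print(x.to(torch.float16))
# bfloat16 keeps float32's 8-bit exponent, so the magnitude survives
# (at reduced mantissa precision)
print(x.to(torch.bfloat16))
```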
Issue: MPS “float64 not supported” Error
Solution: DeviceProcessorStep handles this automatically:
```python
# This automatically converts float64 -> float32 on MPS
device_processor = DeviceProcessorStep(device="mps")
```
See lerobot/processor/device_processor.py:108.
Advanced: Custom Compilation
Compile specific functions:
```python
import torch

class MyPolicy:
    def __init__(self):
        # Compile specific methods
        self.process_observation = torch.compile(
            self.process_observation,
            mode="reduce-overhead"
        )
        self.predict_action = torch.compile(
            self.predict_action,
            mode="max-autotune"
        )

    def process_observation(self, obs):
        # Hot path: compile for low overhead
        ...

    def predict_action(self, obs):
        # Performance critical: compile for max speed
        ...
```
API Reference
torch.compile Settings
- compile_model (bool): Enable torch.compile optimization
- compile_mode (str, default "max-autotune"): Compilation mode: "default", "reduce-overhead", or "max-autotune"
DeviceProcessorStep
See lerobot/processor/device_processor.py:34
- device (str): Target device: "cpu", "cuda", "cuda:0", "mps"
- float_dtype (str): Target float dtype: "float16", "bfloat16", "float32", "float64"
Hardware Backends
- torch.backends.cuda.matmul.allow_tf32: Enable TF32 for matrix multiplication (Ampere+ GPUs)
- torch.backends.cudnn.benchmark: Enable cuDNN auto-tuner for fixed input sizes
- torch.backends.cudnn.allow_tf32: Enable TF32 for cuDNN operations
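All of these flags are plain module attributes and can be read back at runtime, which is handy for sanity-checking a deployment configuration even on machines without a GPU. A small sketch:

```python
import torch

# Backend flags are readable anywhere, GPU or not
flags = {
    "cuda.matmul.allow_tf32": torch.backends.cuda.matmul.allow_tf32,
    "cudnn.benchmark": torch.backends.cudnn.benchmark,
    "cudnn.allow_tf32": torch.backends.cudnn.allow_tf32,
}
for name, value in flags.items():
    print(f"{name} = {value}")
```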