LeRobot policies support PyTorch acceleration techniques like torch.compile, mixed precision training, and hardware-specific optimizations. This guide shows you how to enable and configure these features for faster inference and training.

torch.compile

torch.compile is PyTorch’s JIT compiler that optimizes model execution graphs. LeRobot policies include built-in support for compilation.

Enable Compilation

Set compile_model=True in your policy configuration:
from lerobot.policies.pi0 import Pi0Config, Pi0Policy

config = Pi0Config(
    compile_model=True,
    compile_mode="max-autotune",  # Optimization level
    # ... other config
)

policy = Pi0Policy(config)
See lerobot/policies/pi0/configuration_pi0.py for configuration options.

Compilation Modes

PyTorch offers different compilation modes:
  • default (str): Balanced compilation with good speedup and compilation time
  • reduce-overhead (str): Minimizes Python overhead; good for small models
  • max-autotune (str): Maximum optimization; slower compilation but best performance (recommended for deployment)
# Fast compilation, moderate speedup
config = Pi0Config(
    compile_model=True,
    compile_mode="default"
)

# Maximum performance (slower first run)
config = Pi0Config(
    compile_model=True,
    compile_mode="max-autotune"
)

What Gets Compiled

Different policies compile different components.

Pi0/Pi05 Policies:
# From lerobot/policies/pi0/modeling_pi0.py
if config.compile_model:
    self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
    self.forward = torch.compile(self.forward, mode=config.compile_mode)
Diffusion Policy:
# From lerobot/policies/diffusion/modeling_diffusion.py
if config.compile_model:
    self.unet = torch.compile(self.unet, mode=config.compile_mode)
SAC Policy:
# From lerobot/policies/sac/modeling_sac.py
if self.config.use_torch_compile:
    self.critic_ensemble = torch.compile(self.critic_ensemble)
    self.critic_target = torch.compile(self.critic_target)

Compilation Workflow

First inference triggers compilation. Note that compilation is applied in the policy constructor (see What Gets Compiled above), so compile_model and compile_mode must already be enabled in the config when the policy is instantiated:
import time

policy = Pi0Policy.from_pretrained("lerobot/pi0_model")  # config has compile_model=True, compile_mode="max-autotune"

# First call: slow (compilation happens)
print("First inference (with compilation)...")
start = time.time()
action = policy.predict_action_chunk(observation)
print(f"Time: {time.time() - start:.2f}s")  # e.g., 30s

# Subsequent calls: fast (using compiled graph)
print("Second inference (using compiled graph)...")
start = time.time()
action = policy.predict_action_chunk(observation)
print(f"Time: {time.time() - start:.2f}s")  # e.g., 0.1s

Avoiding Graph Breaks

Some operations cause graph breaks and recompilation. Avoid:
  • Dynamic control flow based on tensor values
  • In-place operations on inputs
  • Python print statements in hot paths
# Bad: Dynamic control flow
if action.mean() > 0.5:  # Graph break!
    action = action * 2

# Good: Use torch operations
action = torch.where(action.mean() > 0.5, action * 2, action)

# Bad: Print in hot path
print(f"Action: {action}")  # Graph break!

# Good: Use torch.compiler.disable
with torch.compiler.disable():
    print(f"Action: {action}")

Hardware Acceleration

CUDA (NVIDIA GPUs)

Use CUDA for NVIDIA GPUs:
policy = policy.to("cuda")

# Enable TF32 for faster matmul on Ampere+ GPUs
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Enable cuDNN auto-tuner
torch.backends.cudnn.benchmark = True

MPS (Apple Silicon)

Use MPS for Apple Silicon Macs:
if torch.backends.mps.is_available():
    policy = policy.to("mps")
    
    # Note: MPS doesn't support float64
    # DeviceProcessorStep automatically converts float64 -> float32
See lerobot/processor/device_processor.py:108 for MPS dtype handling.
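Putting the backends together, it is common to pick the best available device once at startup. A small fallback helper (select_device is a hypothetical name, not a LeRobot API):

```python
import torch

def select_device() -> torch.device:
    """Prefer CUDA, then MPS, then fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = select_device()
# policy = policy.to(device)  # move the policy to the chosen backend
```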

CPU Optimization

Optimize CPU inference:
import os
import torch

# Use all available CPU cores
torch.set_num_threads(os.cpu_count())

# Enable oneDNN (mkldnn) fused kernels on supported CPUs
torch.backends.mkldnn.enabled = True

# Compile with max-autotune for CPU
config.compile_model = True
config.compile_mode = "max-autotune"

Mixed Precision

Use mixed precision (float16/bfloat16) for faster training and inference:

Automatic Mixed Precision (AMP)

For training:
import torch
from torch.amp import autocast, GradScaler

scaler = GradScaler("cuda")

for batch in dataloader:
    optimizer.zero_grad()
    
    # Forward pass in mixed precision
    with autocast("cuda"):
        loss = policy.compute_loss(batch)
    
    # Backward pass with gradient scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
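If you train in bfloat16 instead of float16, the GradScaler is generally unnecessary, because bfloat16 keeps float32's exponent range. A self-contained sketch: the tiny nn.Linear stands in for a real policy, and device_type="cpu" is used so it runs anywhere (swap in "cuda" on GPU):

```python
import torch
import torch.nn as nn

# Toy stand-in for a policy; a real run would use policy.compute_loss(batch)
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = torch.randn(4, 8)

optimizer.zero_grad()
# bfloat16 autocast: no GradScaler needed (same exponent range as float32)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    loss = model(batch).pow(2).mean()
loss.backward()
optimizer.step()
```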

Manual Precision Control

Control dtype via DeviceProcessorStep:
from lerobot.processor import DeviceProcessorStep

# Use bfloat16 (better for training)
device_processor = DeviceProcessorStep(
    device="cuda",
    float_dtype="bfloat16"
)

# Use float16 (better for inference)
device_processor = DeviceProcessorStep(
    device="cuda",
    float_dtype="float16"
)
See lerobot/processor/device_processor.py:34 for dtype options.

Precision Recommendations

  • Training: use bfloat16 for better numerical stability
  • Inference: use float16 for maximum speed
  • CPU: use float32 (half-precision matmuls are slow on most CPUs)
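These recommendations can be encoded as a small helper. pick_float_dtype is a hypothetical name; the hardware check uses torch.cuda.is_bf16_supported():

```python
import torch

def pick_float_dtype(device: str) -> str:
    """Heuristic dtype choice following the recommendations above."""
    if device == "cpu":
        return "float32"  # half-precision matmuls are slow on most CPUs
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return "bfloat16"  # Ampere+ GPUs: float32 exponent range at half the size
    return "float16"
```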

Optimization Checklist

For Inference

import torch

# 1. Enable torch.compile
config.compile_model = True
config.compile_mode = "max-autotune"

# 2. Use appropriate device
policy = policy.to("cuda")

# 3. Enable hardware optimizations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# 4. Use mixed precision
device_processor = DeviceProcessorStep(device="cuda", float_dtype="float16")

# 5. Disable gradient computation
torch.set_grad_enabled(False)

# 6. Use eval mode
policy.eval()
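Steps 5 and 6 can also be scoped with torch.inference_mode(), which disables autograd more aggressively than set_grad_enabled(False). A sketch with a toy module standing in for the policy:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2).eval()  # toy stand-in for a policy
obs = torch.randn(1, 8)

# inference_mode also skips tensor version-counter bookkeeping,
# so it can be slightly faster than torch.set_grad_enabled(False)
with torch.inference_mode():
    out = model(obs)

print(out.requires_grad)  # False: no autograd graph was built
```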

For Training

import torch
from torch.amp import autocast, GradScaler

# 1. Optionally enable torch.compile (slower first epoch)
config.compile_model = True
config.compile_mode = "default"  # Faster compilation

# 2. Use GPU
policy = policy.to("cuda")

# 3. Enable hardware optimizations  
torch.backends.cuda.matmul.allow_tf32 = True

# 4. Use AMP for mixed precision
scaler = GradScaler("cuda")

# 5. Use train mode
policy.train()

# 6. Enable cuDNN auto-tuner (for fixed input sizes)
torch.backends.cudnn.benchmark = True

Benchmarking

Measure speedup from optimizations:
import torch
import time

def benchmark_policy(policy, observation, num_iterations=100, warmup=10):
    """Benchmark policy inference speed."""
    
    # Warmup
    for _ in range(warmup):
        _ = policy.predict_action_chunk(observation)
    
    # Benchmark
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for pending GPU work
    start = time.time()
    
    for _ in range(num_iterations):
        _ = policy.predict_action_chunk(observation)
    
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = time.time() - start
    
    avg_time = elapsed / num_iterations
    fps = 1.0 / avg_time
    
    print(f"Average time: {avg_time*1000:.2f}ms")
    print(f"Throughput: {fps:.1f} FPS")
    
    return avg_time

# Compare configurations
print("Baseline (float32, no compile):")
policy1 = load_policy(compile_model=False, dtype="float32")
time1 = benchmark_policy(policy1, obs)

print("\nWith torch.compile:")
policy2 = load_policy(compile_model=True, dtype="float32")
time2 = benchmark_policy(policy2, obs)

print("\nWith torch.compile + float16:")
policy3 = load_policy(compile_model=True, dtype="float16")
time3 = benchmark_policy(policy3, obs)

print(f"\nSpeedup from compile: {time1/time2:.2f}x")
print(f"Speedup from compile+fp16: {time1/time3:.2f}x")
Example output:
Baseline (float32, no compile):
Average time: 45.3ms
Throughput: 22.1 FPS

With torch.compile:
Average time: 12.1ms
Throughput: 82.6 FPS

With torch.compile + float16:
Average time: 6.8ms
Throughput: 147.1 FPS

Speedup from compile: 3.74x
Speedup from compile+fp16: 6.66x
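time.time() measures the host, so it can mis-attribute asynchronous GPU work between synchronizations. torch.cuda.Event records timestamps on the device itself. A sketch (cuda_time_ms is a hypothetical helper; requires a CUDA device):

```python
import torch

def cuda_time_ms(fn, *args, iters=100):
    """Average GPU time per call in milliseconds, measured via CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()  # wait until 'end' has been recorded
    return start.elapsed_time(end) / iters
```

For example, `cuda_time_ms(policy.predict_action_chunk, observation)` times the policy's hot path directly.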

Troubleshooting

Issue: Compilation Takes Too Long

Solution: Use faster compile mode or disable for development:
# For development: no compilation
config.compile_model = False

# For faster compilation: use "default" mode
config.compile_mode = "default"

# For deployment: use "max-autotune" and cache
config.compile_mode = "max-autotune"
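For deployment, max-autotune results can be reused across restarts: Inductor writes compiled artifacts to an on-disk cache whose location the TORCHINDUCTOR_CACHE_DIR environment variable controls (assumption: default Inductor caching behavior; set it before the first compile):

```python
import os

# Point Inductor's on-disk cache at a persistent path so the expensive
# max-autotune search runs once per machine, not once per process
os.environ["TORCHINDUCTOR_CACHE_DIR"] = os.path.expanduser("~/.cache/torch_inductor")
```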

Issue: Out of Memory with Mixed Precision

Solution: Reduce batch size or use gradient accumulation:
# Reduce batch size
config.batch_size = config.batch_size // 2

# Or use gradient accumulation
from torch.amp import autocast

accumulation_steps = 2
for i, batch in enumerate(dataloader):
    with autocast("cuda"):
        loss = policy.compute_loss(batch) / accumulation_steps
    scaler.scale(loss).backward()
    
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

Issue: Numerical Instability with float16

Solution: Use bfloat16 or gradient scaling:
# Option 1: Use bfloat16 (better numerical stability)
device_processor = DeviceProcessorStep(device="cuda", float_dtype="bfloat16")

# Option 2: Use GradScaler with float16
from torch.amp import GradScaler
scaler = GradScaler("cuda")

Issue: MPS “float64 not supported” Error

Solution: DeviceProcessorStep handles this automatically:
# This automatically converts float64 -> float32 on MPS
device_processor = DeviceProcessorStep(device="mps")
See lerobot/processor/device_processor.py:108.

Advanced: Custom Compilation

Compile specific functions:
import torch

class MyPolicy:
    def __init__(self):
        # Compile specific methods
        self.process_observation = torch.compile(
            self.process_observation,
            mode="reduce-overhead"
        )
        self.predict_action = torch.compile(
            self.predict_action,
            mode="max-autotune"
        )
    
    def process_observation(self, obs):
        # Hot path: compile for low overhead
        ...
    
    def predict_action(self, obs):
        # Performance critical: compile for max speed
        ...

API Reference

torch.compile Settings

  • compile_model (bool, default: False): Enable torch.compile optimization
  • compile_mode (str, default: "max-autotune"): Compilation mode; one of "default", "reduce-overhead", or "max-autotune"

DeviceProcessorStep

See lerobot/processor/device_processor.py:34
  • device (str): Target device: "cpu", "cuda", "cuda:0", "mps"
  • float_dtype (str | None): Target float dtype: "float16", "bfloat16", "float32", "float64"

Hardware Backends

  • torch.backends.cuda.matmul.allow_tf32 (bool): Enable TF32 for matrix multiplication (Ampere+ GPUs)
  • torch.backends.cudnn.benchmark (bool): Enable cuDNN auto-tuner for fixed input sizes
  • torch.backends.cudnn.allow_tf32 (bool): Enable TF32 for cuDNN operations
