LeRobot is designed to be extensible. You can implement your own custom policy architectures and leverage LeRobot’s data collection, training infrastructure, and visualization tools.

Policy Interface

All LeRobot policies inherit from PreTrainedPolicy and implement a standard interface:
from lerobot.policies.pretrained import PreTrainedPolicy
import torch
import torch.nn as nn

class MyCustomPolicy(PreTrainedPolicy, nn.Module):
    """Custom policy implementation."""
    
    def __init__(self, config, dataset_stats=None):
        super().__init__(config)
        nn.Module.__init__(self)
        
        # Initialize your model architecture
        self.encoder = ...
        self.decoder = ...
    
    def forward(self, batch):
        """Forward pass for training.
        
        Args:
            batch: Dictionary with keys like 'observation.state', 'observation.image', 'action'
            
        Returns:
            loss: Training loss (scalar tensor)
            output_dict: Dictionary with additional outputs for logging
        """
        # Extract inputs
        state = batch['observation.state']
        images = batch['observation.image.side']
        actions = batch['action']
        
        # Forward pass
        encoding = self.encoder(state, images)
        predicted_actions = self.decoder(encoding)
        
        # Compute loss
        loss = nn.functional.mse_loss(predicted_actions, actions)
        
        # Additional outputs for logging
        output_dict = {
            'mse_loss': loss.item(),
        }
        
        return loss, output_dict
    
    def select_action(self, obs):
        """Select action for inference.
        
        Args:
            obs: Dictionary with observation keys
            
        Returns:
            action: Predicted action tensor
        """
        with torch.no_grad():
            state = obs['observation.state']
            images = obs['observation.image.side']
            
            encoding = self.encoder(state, images)
            action = self.decoder(encoding)
            
        return action
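
Depending on your LeRobot version, PreTrainedPolicy may also declare further abstract methods that concrete policies must provide. A minimal sketch of two common ones, written as extra methods in the class body (treat the exact names and signatures as assumptions to check against your installed version):
    def reset(self):
        """Clear any cached state (e.g. action queues) when a new episode starts."""
        pass
    
    def get_optim_params(self):
        """Return the parameters the optimizer should update."""
        return self.parameters()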

Configuration Class

Define a configuration class for your policy:
from lerobot.configs.policies import PreTrainedConfig
from dataclasses import dataclass, field

@dataclass
class MyCustomPolicyConfig(PreTrainedConfig):
    """Configuration for MyCustomPolicy."""
    
    # Required: policy type identifier
    type: str = "my_custom_policy"
    
    # Model architecture parameters
    hidden_dim: int = 256
    num_layers: int = 4
    dropout: float = 0.1
    
    # Training parameters
    optimizer_lr: float = 1e-4
    scheduler_decay_lr: float = 1e-5
    
    # Observation/action parameters (inherited from PreTrainedConfig)
    # input_features: dict
    # output_features: dict
    # device: str
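
Field defaults can be overridden at construction time, and the inherited feature/device fields are set the same way. A quick usage sketch:
config = MyCustomPolicyConfig(
    hidden_dim=512,
    num_layers=6,
    optimizer_lr=3e-4,
)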

Full Implementation Example

Here’s a complete example implementing a simple MLP policy:
import torch
import torch.nn as nn
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.configs.policies import PreTrainedConfig
from dataclasses import dataclass, field

@dataclass
class MLPPolicyConfig(PreTrainedConfig):
    type: str = "mlp_policy"
    hidden_dims: list[int] = field(default_factory=lambda: [256, 256, 256])
    activation: str = "relu"
    dropout: float = 0.1

class MLPPolicy(PreTrainedPolicy, nn.Module):
    """Simple MLP policy for robot control."""
    
    def __init__(self, config: MLPPolicyConfig, dataset_stats=None):
        PreTrainedPolicy.__init__(self, config)
        nn.Module.__init__(self)
        
        self.config = config
        
        # Calculate input dimension from features
        input_dim = sum(ft.shape[0] for ft in config.input_features.values())
        output_dim = sum(ft.shape[0] for ft in config.output_features.values())
        
        # Build MLP layers
        layers = []
        prev_dim = input_dim
        
        for hidden_dim in config.hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(self._get_activation(config.activation))
            layers.append(nn.Dropout(config.dropout))
            prev_dim = hidden_dim
        
        # Output layer
        layers.append(nn.Linear(prev_dim, output_dim))
        
        self.network = nn.Sequential(*layers)
        
        # Normalization (optional but recommended)
        self.dataset_stats = dataset_stats
    
    def _get_activation(self, name):
        if name == "relu":
            return nn.ReLU()
        elif name == "tanh":
            return nn.Tanh()
        elif name == "gelu":
            return nn.GELU()
        else:
            raise ValueError(f"Unknown activation: {name}")
    
    def forward(self, batch):
        # Concatenate all input features
        inputs = []
        for key in sorted(self.config.input_features.keys()):
            inputs.append(batch[key])
        x = torch.cat(inputs, dim=-1)
        
        # Forward pass
        predicted_actions = self.network(x)
        
        # Get ground truth actions
        true_actions = batch['action']
        
        # Compute loss
        loss = nn.functional.mse_loss(predicted_actions, true_actions)
        
        output_dict = {
            'mse_loss': loss.item(),
        }
        
        return loss, output_dict
    
    def select_action(self, obs):
        """Select action during inference."""
        with torch.no_grad():
            # Concatenate inputs
            inputs = []
            for key in sorted(self.config.input_features.keys()):
                inputs.append(obs[key])
            x = torch.cat(inputs, dim=-1)
            
            # Predict action
            action = self.network(x)
            
        return action

Registering Your Policy

To use your policy with LeRobot’s training and evaluation scripts, register it in the policy factory:
# In lerobot/policies/factory.py

from lerobot.policies.my_custom_policy.modeling_my_custom_policy import MyCustomPolicy
from lerobot.policies.my_custom_policy.configuration_my_custom_policy import MyCustomPolicyConfig

def get_policy_class(name: str):
    # ... existing policies ...
    
    if name == "my_custom_policy":
        return MyCustomPolicy
    
    # ... rest of function ...

def make_policy_config(policy_type: str, **kwargs):
    # ... existing configs ...
    
    if policy_type == "my_custom_policy":
        return MyCustomPolicyConfig(**kwargs)
    
    # ... rest of function ...
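
With both hooks in place, you can sanity-check the registration directly through the factory helpers:
from lerobot.policies.factory import get_policy_class, make_policy_config

config = make_policy_config("my_custom_policy", hidden_dim=512)
policy_cls = get_policy_class("my_custom_policy")
assert policy_cls is MyCustomPolicy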

Using Your Custom Policy

Training

Once registered, use your policy with the training CLI:
lerobot-train \
  --policy.type=my_custom_policy \
  --policy.hidden_dim=512 \
  --policy.dropout=0.1 \
  --dataset.repo_id=your_username/your_dataset \
  --steps=50000 \
  --batch_size=64

Evaluation

Evaluate your trained policy:
lerobot-eval \
  --policy.path=your_username/my_custom_policy \
  --env.type=gym \
  --env.task=PandaPickPlace-v3 \
  --eval.n_episodes=10

Programmatic Usage

import torch

from lerobot.policies.my_custom_policy.modeling_my_custom_policy import MyCustomPolicy
from lerobot.policies.my_custom_policy.configuration_my_custom_policy import MyCustomPolicyConfig

# Create config
config = MyCustomPolicyConfig(
    input_features=input_features,
    output_features=output_features,
    hidden_dims=[512, 512, 512]
)

# Initialize policy
policy = MyCustomPolicy(config)

# Train (`dataloader` is your training DataLoader)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
for batch in dataloader:
    loss, outputs = policy.forward(batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Inference
policy.eval()
action = policy.select_action(observation)

Advanced Features

Vision Encoders

Integrate vision encoders for image observations:
from lerobot.policies.utils import get_image_encoder

class MyVisionPolicy(PreTrainedPolicy, nn.Module):
    def __init__(self, config):
        super().__init__(config)
        
        # Create vision encoder
        self.vision_encoder = get_image_encoder(
            encoder_type="resnet18",
            input_channels=3,
            output_dim=512
        )
        
        # Create policy head (action dimension comes from the config's output features)
        action_dim = config.output_features['action'].shape[0]
        self.policy_head = nn.Linear(512, action_dim)
    
    def forward(self, batch):
        # Encode images
        images = batch['observation.image.side']
        features = self.vision_encoder(images)
        
        # Predict actions
        actions = self.policy_head(features)
        
        loss = nn.functional.mse_loss(actions, batch['action'])
        return loss, {}
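
If the robot exposes several camera views, a common pattern is to encode each view separately and concatenate the features before the policy head. A minimal sketch as an extra method (the camera keys below are illustrative, not fixed LeRobot names):
    def encode_views(self, batch):
        # Encode each camera view independently and fuse by concatenation
        features = [
            self.vision_encoder(batch[key])
            for key in ('observation.image.side', 'observation.image.top')
        ]
        return torch.cat(features, dim=-1)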

Action Chunking

Implement action chunking for temporal consistency:
class ChunkedPolicy(PreTrainedPolicy, nn.Module):
    def __init__(self, config):
        super().__init__(config)
        self.chunk_size = 16  # Predict 16 future actions
        self.action_dim = config.output_features['action'].shape[0]
        obs_dim = config.input_features['observation.state'].shape[0]
        
        self.network = nn.Sequential(
            nn.Linear(obs_dim, 512),
            nn.ReLU(),
            nn.Linear(512, self.action_dim * self.chunk_size)
        )
    
    def forward(self, batch):
        # Predict action chunk
        obs = batch['observation.state']
        action_chunk = self.network(obs)
        action_chunk = action_chunk.reshape(-1, self.chunk_size, self.action_dim)
        
        # Supervise with ground truth chunk
        true_actions = batch['action']  # Shape: [batch, chunk_size, action_dim]
        loss = nn.functional.mse_loss(action_chunk, true_actions)
        
        return loss, {}
    
    def select_action(self, obs):
        """Returns next action from chunk."""
        with torch.no_grad():
            action_chunk = self.network(obs['observation.state'])
            action_chunk = action_chunk.reshape(-1, self.chunk_size, self.action_dim)
            # Return first action in chunk
            return action_chunk[:, 0, :]
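
Returning only the first action discards most of the predicted chunk. A common refinement is to cache the chunk in a queue and re-run inference only once it is exhausted; a minimal sketch of that idea (not LeRobot's exact implementation):
from collections import deque

class QueuedChunkedPolicy(ChunkedPolicy):
    def __init__(self, config):
        super().__init__(config)
        self._action_queue = deque()
    
    def reset(self):
        # Call at the start of each episode
        self._action_queue.clear()
    
    def select_action(self, obs):
        with torch.no_grad():
            if not self._action_queue:
                chunk = self.network(obs['observation.state'])
                chunk = chunk.reshape(-1, self.chunk_size, self.action_dim)
                # Split the chunk along time and queue the per-step actions
                self._action_queue.extend(chunk.transpose(0, 1))
            return self._action_queue.popleft()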

Normalization

Use dataset statistics for input/output normalization:
class NormalizedPolicy(PreTrainedPolicy, nn.Module):
    def __init__(self, config, dataset_stats=None):
        super().__init__(config)
        
        # Store normalization statistics
        if dataset_stats:
            self.register_buffer(
                'state_mean',
                dataset_stats['observation.state']['mean']
            )
            self.register_buffer(
                'state_std',
                dataset_stats['observation.state']['std']
            )
            self.register_buffer(
                'action_mean',
                dataset_stats['action']['mean']
            )
            self.register_buffer(
                'action_std',
                dataset_stats['action']['std']
            )
        
        self.network = nn.Sequential(...)
    
    def normalize_state(self, state):
        return (state - self.state_mean) / (self.state_std + 1e-8)
    
    def unnormalize_action(self, action):
        return action * self.action_std + self.action_mean
    
    def forward(self, batch):
        # Normalize inputs
        state = self.normalize_state(batch['observation.state'])
        
        # Predict normalized action
        pred_action_norm = self.network(state)
        
        # Loss in normalized space
        true_action_norm = (batch['action'] - self.action_mean) / (self.action_std + 1e-8)
        loss = nn.functional.mse_loss(pred_action_norm, true_action_norm)
        
        return loss, {}
    
    def select_action(self, obs):
        with torch.no_grad():
            state = self.normalize_state(obs['observation.state'])
            action_norm = self.network(state)
            action = self.unnormalize_action(action_norm)
        return action
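
For reference, dataset_stats is just a nested dict of tensors keyed the same way as the batch. A hand-built sketch, given a config built as in the earlier sections (the 7-dim shapes are placeholders for your robot's state/action dimensions):
import torch

dataset_stats = {
    'observation.state': {'mean': torch.zeros(7), 'std': torch.ones(7)},
    'action': {'mean': torch.zeros(7), 'std': torch.ones(7)},
}
policy = NormalizedPolicy(config, dataset_stats=dataset_stats)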

Best Practices

1. Follow the standard interface

Implement forward() for training and select_action() for inference:
def forward(self, batch) -> tuple[torch.Tensor, dict]:
    """Returns (loss, output_dict)."""
    pass

def select_action(self, obs) -> torch.Tensor:
    """Returns action tensor."""
    pass

2. Use configuration dataclasses

Define all hyperparameters in a config class:
@dataclass
class MyPolicyConfig(PreTrainedConfig):
    type: str = "my_policy"
    hidden_dim: int = 256
    # ... other params

3. Support saving/loading

Inherit from PreTrainedPolicy for automatic checkpoint handling:
# Save
policy.save_pretrained("outputs/my_policy")

# Load
policy = MyPolicy.from_pretrained("outputs/my_policy")

4. Add logging outputs

Return useful metrics in output_dict for monitoring:
output_dict = {
    'mse_loss': mse_loss.item(),
    'action_mean': actions.mean().item(),
    'action_std': actions.std().item(),
}
return loss, output_dict

Testing Your Policy

Test your policy implementation:
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Create dummy config (input_features / output_features as defined for your dataset)
config = MyCustomPolicyConfig(
    input_features=input_features,
    output_features=output_features
)

# Initialize policy
policy = MyCustomPolicy(config)

# Test forward pass
dataset = LeRobotDataset("lerobot/pusht")
batch = next(iter(torch.utils.data.DataLoader(dataset, batch_size=4)))

loss, outputs = policy.forward(batch)
print(f"Loss: {loss.item()}")
print(f"Outputs: {outputs}")

# Test inference
policy.eval()
obs = batch
action = policy.select_action(obs)
print(f"Action shape: {action.shape}")

Next Steps

Examples

See the existing policy implementations under lerobot/policies (for example act and diffusion) for complete reference architectures.
