Skip to main content
Subtask annotations break down complex manipulation tasks into interpretable, fine-grained steps. They’re useful for hierarchical policies, reward modeling, and understanding robot behavior.

What are Subtasks?

While a task describes the overall goal (e.g., “Pick up the apple and place it in the basket”), subtasks break execution into atomic steps:
  1. “Approach the apple”
  2. “Grasp the apple”
  3. “Lift the apple”
  4. “Move to basket”
  5. “Release the apple”
Each frame can be annotated with its corresponding subtask, enabling models to learn intermediate stages.

Dataset Structure

Subtask information is stored in meta/subtasks.parquet:
my-dataset/
├── meta/
│   ├── info.json
│   ├── tasks.parquet
│   ├── subtasks.parquet     # Subtask index → subtask string
│   └── ...
└── ...

Subtasks File Format

meta/subtasks.parquet maps indices to descriptions:
| subtask_index (index column) | subtask |
|---|---|
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| 3 | "Move to basket" |
| 4 | "Release the apple" |

Frame-Level Annotations

Each frame includes a subtask_index field:
{
    "index": 42,
    "timestamp": 1.4,
    "episode_index": 0,
    "task_index": 0,
    "subtask_index": 2,  # References "Lift the apple"
    "observation.state": [...],
    "action": [...],
}

Annotating Datasets

Use the Hugging Face Space to annotate datasets: https://huggingface.co/spaces/lerobot/annotate

Steps:

  1. Load your dataset
  2. Define subtask labels
  3. Annotate frame ranges for each episode
  4. Push annotated dataset to Hub
You can also run the annotator locally: github.com/huggingface/lerobot-annotate

Loading Datasets with Subtasks

from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Load dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")

# Subtask support needs both the per-frame feature and the metadata table.
has_subtasks = (
    "subtask_index" in dataset.features
    and dataset.meta.subtasks is not None
)

if has_subtasks:
    print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
    print("Subtasks:")
    for name in dataset.meta.subtasks.index:
        print(f"  - {name}")

Accessing Subtask Information

# Pull a single frame out of the dataset
frame = dataset[100]

print(f"Task: {frame['task']}")              # e.g. "Collect the fruit"
print(f"Subtask: {frame['subtask']}")        # e.g. "Grasp the apple"
print(f"Task index: {frame['task_index']}")  # tensor(0)
print(f"Subtask index: {frame['subtask_index']}")  # tensor(2)

Using with DataLoader

import torch
from torch.utils.data import DataLoader

dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")

# Standard PyTorch loader; subtask fields are collated like any other key.
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

for batch in dataloader:
    # Subtask annotations arrive as two parallel fields per batch:
    subtasks = batch["subtask"]               # list of subtask strings
    subtask_indices = batch["subtask_index"]  # tensor of integer indices

    states = batch["observation.state"]
    actions = batch["action"]

    # Use for training
    print(f"Batch subtasks: {set(subtasks)}")

Training with Subtasks

Hierarchical Policy

Predict both actions and current subtask:
import torch.nn as nn

class HierarchicalPolicy(nn.Module):
    """Multi-head policy that jointly predicts actions and the current subtask.

    A shared MLP trunk encodes the robot state; two linear heads decode the
    features into (a) a continuous action vector and (b) classification
    logits over the known subtasks.
    """

    def __init__(self, state_dim, action_dim, num_subtasks):
        super().__init__()
        # Shared two-layer MLP trunk producing a 256-d feature vector.
        self.encoder = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
        )
        # Regression head for continuous actions.
        self.action_head = nn.Linear(256, action_dim)
        # Classification head: logits over subtask labels.
        self.subtask_head = nn.Linear(256, num_subtasks)

    def forward(self, states):
        """Return ``(actions, subtask_logits)`` for a batch of states."""
        feats = self.encoder(states)
        return self.action_head(feats), self.subtask_head(feats)

# Training loop
model = HierarchicalPolicy(
    state_dim=14,
    action_dim=14,
    num_subtasks=len(dataset.meta.subtasks)
)

criterion_action = nn.MSELoss()
criterion_subtask = nn.CrossEntropyLoss()
# Fix: the optimizer was stepped below but never created.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for batch in dataloader:
    states = batch["observation.state"]
    true_actions = batch["action"]
    true_subtasks = batch["subtask_index"]

    # Forward pass
    pred_actions, pred_subtask_logits = model(states)

    # Compute losses
    action_loss = criterion_action(pred_actions, true_actions)
    subtask_loss = criterion_subtask(pred_subtask_logits, true_subtasks)

    # Auxiliary subtask loss is down-weighted relative to the action loss.
    total_loss = action_loss + 0.1 * subtask_loss

    # Fix: clear stale gradients so they don't accumulate across batches.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

Stage-Aware Reward Modeling (SARM)

Build reward models that understand task progression:
class SARMRewardModel(nn.Module):
    """Stage-aware reward model.

    Given a robot state, emits (a) classification logits over the discrete
    subtask stages and (b) a scalar in [0, 1] estimating progress within
    the current stage.
    """

    def __init__(self, state_dim, num_subtasks):
        super().__init__()
        # Shared feature extractor: two-layer MLP with 256-d hidden size.
        self.encoder = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
        )
        # Which subtask is the robot in? (discrete)
        self.stage_classifier = nn.Linear(256, num_subtasks)
        # How far along within that subtask? (continuous, 0-1)
        self.progress_regressor = nn.Linear(256, 1)

    def forward(self, states):
        """Return ``(stage_logits, progress)`` for a batch of states."""
        feats = self.encoder(states)
        logits = self.stage_classifier(feats)
        # Sigmoid squashes the raw regressor output into [0, 1].
        frac = torch.sigmoid(self.progress_regressor(feats))
        return logits, frac

# Training
import torch.nn.functional as F  # fix: F was used below but never imported

model = SARMRewardModel(
    state_dim=14,
    num_subtasks=len(dataset.meta.subtasks)
)
# Fix: define an optimizer so the gradient update is actually applied.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for batch in dataloader:
    states = batch["observation.state"]
    true_stages = batch["subtask_index"]

    # Compute progress within episode (fraction of the episode elapsed).
    frame_indices = batch["frame_index"]
    # NOTE(review): get_episode_lengths must return one length per element
    # of episode_index — it is assumed to be defined by the reader.
    episode_lengths = get_episode_lengths(batch["episode_index"])
    true_progress = frame_indices / episode_lengths

    # Forward
    pred_stages, pred_progress = model(states)

    # Losses
    stage_loss = F.cross_entropy(pred_stages, true_stages)
    progress_loss = F.mse_loss(pred_progress.squeeze(), true_progress)

    loss = stage_loss + progress_loss
    # Fix: zero stale gradients and step the optimizer each batch.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Subtask Analysis

Distribution Analysis

import matplotlib.pyplot as plt
from collections import Counter

# Tally how many frames carry each subtask label (single pass).
subtask_counts = Counter(dataset[i]["subtask"] for i in range(len(dataset)))

# Plot distribution
labels = list(subtask_counts)
counts = [subtask_counts[name] for name in labels]

plt.figure(figsize=(12, 6))
plt.bar(labels, counts)
plt.xlabel("Subtask")
plt.ylabel("Number of Frames")
plt.title("Subtask Distribution")
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig("subtask_distribution.png")

Episode-Level Analysis

# Print the subtask transition points for the first few episodes.
for ep_idx in range(min(5, dataset.num_episodes)):
    ep = dataset.meta.episodes[ep_idx]
    start = ep["dataset_from_index"]
    stop = ep["dataset_to_index"]

    print(f"\nEpisode {ep_idx}:")
    print(f"  Task: {ep['tasks']}")

    # Report each frame where the subtask label changes.
    current = None
    for frame_idx in range(start, stop):
        label = dataset[frame_idx]["subtask"]
        if label != current:
            print(f"  Frame {frame_idx - start:3d}: {label}")
            current = label

Progress Visualization

Monitor robot execution by tracking subtask progression:
import numpy as np

def visualize_execution(model, dataset, episode_idx):
    """Visualize predicted vs. ground truth subtasks.

    Runs ``model`` frame-by-frame over one episode, plots the predicted and
    annotated subtask indices, saves the figure to
    ``episode_{episode_idx}_subtasks.png``, and prints per-frame accuracy.

    Args:
        model: Module returning ``(_, subtask_logits)`` for a batch of
            states (e.g. ``HierarchicalPolicy``).
        dataset: Dataset with per-frame ``subtask_index`` annotations.
        episode_idx: Index of the episode to visualize.
    """
    # Episode metadata gives the [from, to) frame range within the dataset.
    ep = dataset.meta.episodes[episode_idx]
    from_idx = ep["dataset_from_index"]
    to_idx = ep["dataset_to_index"]
    
    true_subtasks = []
    pred_subtasks = []
    
    for i in range(from_idx, to_idx):
        sample = dataset[i]
        # Add a batch dimension: the model expects (batch, state_dim).
        state = sample["observation.state"].unsqueeze(0)
        
        # Get prediction (inference only, so no gradients are needed)
        with torch.no_grad():
            _, subtask_logits = model(state)
            pred_idx = subtask_logits.argmax(dim=-1).item()
        
        true_idx = sample["subtask_index"].item()
        
        true_subtasks.append(true_idx)
        pred_subtasks.append(pred_idx)
    
    # Plot ground-truth vs. predicted subtask index over the episode.
    plt.figure(figsize=(15, 5))
    plt.plot(true_subtasks, label="Ground Truth", marker='o')
    plt.plot(pred_subtasks, label="Predicted", marker='x')
    plt.xlabel("Frame")
    plt.ylabel("Subtask Index")
    plt.title(f"Episode {episode_idx} - Subtask Progression")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(f"episode_{episode_idx}_subtasks.png")
    
    # Print per-frame classification accuracy over the episode.
    accuracy = np.mean(np.array(true_subtasks) == np.array(pred_subtasks))
    print(f"Subtask prediction accuracy: {accuracy*100:.1f}%")

Subtask-Conditioned Policies

Use ground truth or predicted subtasks as input:
class SubtaskConditionedPolicy(nn.Module):
    """Policy whose action prediction is conditioned on a subtask label.

    The integer subtask index is mapped to a learned embedding, which is
    concatenated with the state before being fed through an MLP trunk.
    """

    def __init__(self, state_dim, action_dim, num_subtasks, embedding_dim=32):
        super().__init__()
        # Learned lookup table: subtask index -> dense vector.
        self.subtask_embedding = nn.Embedding(num_subtasks, embedding_dim)
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )
        self.action_head = nn.Linear(256, action_dim)

    def forward(self, states, subtask_indices):
        """Predict actions for ``states`` under the given subtask labels."""
        emb = self.subtask_embedding(subtask_indices)
        # Fuse state and subtask context along the feature axis.
        fused = torch.cat([states, emb], dim=-1)
        return self.action_head(self.encoder(fused))

# Training
model = SubtaskConditionedPolicy(
    state_dim=14,
    action_dim=14,
    num_subtasks=len(dataset.meta.subtasks)
)
# Fix: the optimizer was stepped below but never created.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for batch in dataloader:
    states = batch["observation.state"]
    subtasks = batch["subtask_index"]
    true_actions = batch["action"]

    pred_actions = model(states, subtasks)
    # Fix: the snippet used bare `F`, which was never imported; use the
    # fully qualified name so the loop runs as written.
    loss = nn.functional.mse_loss(pred_actions, true_actions)

    # Fix: clear accumulated gradients before each backward pass.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Example Datasets

Datasets with subtask annotations:
# Load and explore
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")

print(f"Total episodes: {dataset.num_episodes}")
print(f"Total frames: {dataset.num_frames}")
print(f"\nAvailable subtasks:")

# Fix: count every frame's subtask in one pass instead of rescanning the
# whole dataset once per subtask (was O(frames x subtasks)).
from collections import Counter

frame_counts = Counter(dataset[i]["subtask"] for i in range(len(dataset)))
for idx, subtask in enumerate(dataset.meta.subtasks.index):
    count = frame_counts[subtask]  # Counter returns 0 for unseen labels
    print(f"  {idx}. {subtask}: {count} frames")

API Reference

Dataset Properties

| Property | Type | Description |
|---|---|---|
| `dataset.meta.subtasks` | `pd.DataFrame` or `None` | Subtask index to name mapping |
| `dataset.features["subtask_index"]` | `dict` | Feature spec if subtasks present |

Sample Keys

| Key | Type | Description |
|---|---|---|
| `subtask_index` | `torch.Tensor` | Integer subtask index |
| `subtask` | `str` | Natural language subtask description |

Build docs developers (and LLMs) love