Overview

VQ-BeT (Vector-Quantized Behavior Transformer) is a policy architecture that generates robot behaviors by predicting sequences of discrete action tokens in a learned latent space. It uses a two-stage training process: first a residual VQ-VAE is trained to encode action sequences into discrete codes, then a GPT-style transformer is trained to predict these codes from observations. The policy was introduced in Behavior Generation with Latent Actions and achieves strong performance on manipulation benchmarks while being more sample-efficient than comparable behavior-cloning approaches.

Key Features

  • Two-Stage Training: Separate VQ-VAE and GPT training phases
  • Residual Vector Quantization: Multi-layer VQ for expressive action encoding
  • Action Chunking with Tokens: Predicts multiple action tokens, each representing a chunk of actions
  • Discrete Latent Space: Enables stable training and efficient exploration
  • Offset Prediction: Fine-tunes discrete predictions with continuous offsets
  • GPT Architecture: Autoregressive transformer for action token prediction

Architecture

VQ-BeT consists of two main components:

1. Residual VQ-VAE (Stage 1)

  • Encoder: Compresses action sequences into latent representations
  • Residual VQ Layer: Multi-layer vector quantization with two codebooks (primary and secondary)
  • Decoder: Reconstructs action sequences from quantized codes
  • Purpose: Learn a discrete latent action space

2. VQ-BeT Policy (Stage 2)

  • Vision Encoder: ResNet backbone with spatial softmax for image features
  • State Projector: Projects proprioceptive state to GPT input dimension
  • GPT Model: Autoregressive transformer that predicts action tokens
  • Prediction Heads:
    • Primary code prediction (first VQ layer)
    • Secondary code prediction (second VQ layer)
    • Offset prediction (continuous refinement)

Training

VQ-BeT requires two-stage training:

Stage 1: VQ-VAE Training

The VQ-VAE is trained during the first n_vqvae_training_steps optimization steps of the run (default: 20,000).

Stage 2: Policy Training

After the VQ-VAE stage, its codebooks are frozen and the GPT model is trained to predict action codes (plus continuous offsets) from observations.

Basic Training Command

lerobot-train \
  --policy.type=vqbet \
  --dataset.repo_id=lerobot/pusht

Training with Custom Configuration

lerobot-train \
  --policy.type=vqbet \
  --dataset.repo_id=lerobot/pusht \
  --policy.n_obs_steps=5 \
  --policy.n_action_pred_token=3 \
  --policy.action_chunk_size=5 \
  --policy.n_vqvae_training_steps=20000 \
  --policy.vqvae_n_embed=16 \
  --policy.gpt_n_layer=8 \
  --batch_size=256 \
  --steps=250000

Python API Training Example

from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
from lerobot.policies.factory import make_pre_post_processors

# Set up
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_id = "lerobot/pusht"

# Configure policy features from dataset
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)

output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}

# Create policy with configuration
cfg = VQBeTConfig(
    input_features=input_features,
    output_features=output_features,
    n_obs_steps=5,
    n_action_pred_token=3,  # Predict 3 action tokens
    action_chunk_size=5,    # Each token represents 5 action steps
    n_vqvae_training_steps=20000,
    vqvae_n_embed=16,       # 16 codes per VQ layer
    vqvae_embedding_dim=256,
    gpt_n_layer=8,
    gpt_n_head=8,
    spatial_softmax_num_keypoints=32
)
policy = VQBeTPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)

policy.train()
policy.to(device)

# Set up dataset
def make_delta_timestamps(delta_indices, fps):
    if delta_indices is None:
        return [0]
    return [i / fps for i in delta_indices]

delta_timestamps = {
    "observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
    "action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
delta_timestamps |= {
    k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps)
    for k in cfg.image_features
}

dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)

# Create optimizer with separate learning rates for VQ-VAE and policy
optimizer_params = policy.get_optim_params()
optimizer = cfg.get_optimizer_preset().build(optimizer_params)
scheduler = cfg.get_scheduler_preset().build(optimizer)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    pin_memory=device.type != "cpu",
    drop_last=True,
)

# Training loop (the policy switches from VQ-VAE training to GPT training
# automatically once n_vqvae_training_steps optimization steps have been taken)
training_steps = 25000  # must exceed n_vqvae_training_steps to reach Stage 2
step = 0
done = False
while not done:
    for batch in dataloader:
        batch = preprocessor(batch)
        loss, output_dict = policy.forward(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
        step += 1
        if step >= training_steps:
            done = True
            break
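Once training finishes, the policy (weights plus configuration) can be written to disk with save_pretrained and later reloaded with from_pretrained; the output path below is just an example:

# Save the trained policy to a local directory (Path was imported at the top of the example;
# the directory name is arbitrary).
output_directory = Path("outputs/train/vqbet_pusht")
output_directory.mkdir(parents=True, exist_ok=True)
policy.save_pretrained(output_directory)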

Configuration Parameters

Input/Output Structure

  • n_obs_steps (int, default: 5): Number of observation steps to pass to the policy.
  • n_action_pred_token (int, default: 3): Total number of action tokens to predict (current + future).
  • action_chunk_size (int, default: 5): Number of action steps represented by each action token.

Vision Backbone

  • vision_backbone (str, default: "resnet18"): ResNet variant to use for image encoding.
  • crop_shape (tuple[int, int] | None, default: (84, 84)): (H, W) shape to crop images to. Must fit within the input image size.
  • crop_is_random (bool, default: true): Whether to use random crops during training (center crops are always used during evaluation).
  • pretrained_backbone_weights (str | None, default: null): Pretrained weights from torchvision. None means random initialization.
  • use_group_norm (bool, default: true): Replace batch normalization with group normalization in the backbone.
  • spatial_softmax_num_keypoints (int, default: 32): Number of keypoints for the spatial softmax operation.

VQ-VAE Configuration

  • n_vqvae_training_steps (int, default: 20000): Number of optimization steps for training the residual VQ-VAE (Stage 1).
  • vqvae_n_embed (int, default: 16): Number of embedding vectors in each RVQ codebook layer.
  • vqvae_embedding_dim (int, default: 256): Dimension of each embedding vector in the RVQ dictionary.
  • vqvae_enc_hidden_dim (int, default: 128): Hidden dimension size for the VQ-VAE encoder/decoder.
  • optimizer_vqvae_lr (float, default: 1e-3): Learning rate for VQ-VAE training (Stage 1).
  • optimizer_vqvae_weight_decay (float, default: 1e-4): Weight decay for the VQ-VAE optimizer.

GPT Configuration

  • gpt_block_size (int, default: 500): Maximum block size (context length) for the GPT model.
  • gpt_input_dim (int, default: 512): Input dimension for the GPT (also used as the observation feature dimension).
  • gpt_output_dim (int, default: 512): Output dimension of the GPT (input to the prediction heads).
  • gpt_n_layer (int, default: 8): Number of transformer layers in the GPT model.
  • gpt_n_head (int, default: 8): Number of attention heads in the GPT model.
  • gpt_hidden_dim (int, default: 512): Hidden dimension size for the GPT feed-forward layers.
  • dropout (float, default: 0.1): Dropout rate for the GPT model.

Loss Weights

  • offset_loss_weight (float, default: 10000.0): Weight multiplier for the continuous offset prediction loss.
  • primary_code_loss_weight (float, default: 5.0): Weight multiplier for the primary VQ code prediction loss.
  • secondary_code_loss_weight (float, default: 0.5): Weight multiplier for the secondary VQ code prediction loss.
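As a rough sketch of how these weights combine into a single training loss (the exact loss terms are defined in modeling_vqbet.py; the cross-entropy and L1 choices below are assumptions for illustration only):

import torch
import torch.nn.functional as F

# Illustrative only: dummy logits/targets standing in for the policy's outputs.
primary_logits = torch.randn(8, 16)       # batch of 8, vqvae_n_embed=16 codes
secondary_logits = torch.randn(8, 16)
primary_target = torch.randint(0, 16, (8,))
secondary_target = torch.randint(0, 16, (8,))
pred_offset = torch.randn(8, 2)           # continuous offsets (2-D action, as in PushT)
target_offset = torch.randn(8, 2)

loss = (
    5.0 * F.cross_entropy(primary_logits, primary_target)        # primary_code_loss_weight
    + 0.5 * F.cross_entropy(secondary_logits, secondary_target)  # secondary_code_loss_weight
    + 10000.0 * F.l1_loss(pred_offset, target_offset)            # offset_loss_weight
)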

Inference

  • bet_softmax_temperature (float, default: 0.1): Sampling temperature for code selection during inference. Lower values are more deterministic.
  • sequentially_select (bool, default: false): Whether to select the primary code first and then the secondary code (true), or sample both jointly (false).
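As a standalone illustration of what the temperature does (this is not the policy's actual sampling code), dividing the code logits by a smaller temperature sharpens the distribution:

import torch

logits = torch.tensor([2.0, 1.0, 0.5])
for temperature in (1.0, 0.1):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, probs)
# At temperature 0.1 almost all probability mass sits on the highest-logit code,
# so sampling behaves nearly deterministically.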

Optimization

  • optimizer_lr (float, default: 1e-4): Learning rate for policy training (Stage 2).
  • optimizer_betas (tuple, default: (0.95, 0.999)): Beta parameters for the Adam optimizer.
  • optimizer_eps (float, default: 1e-8): Epsilon value for the Adam optimizer.
  • optimizer_weight_decay (float, default: 1e-6): Weight decay for the policy optimizer.
  • scheduler_warmup_steps (int, default: 500): Number of warmup steps for the learning rate scheduler.

Normalization

  • normalization_mapping (dict, default: {"VISUAL": "IDENTITY", "STATE": "MIN_MAX", "ACTION": "MIN_MAX"}): Normalization mode for each feature type. Note that VQ-BeT uses IDENTITY for visual features.
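To override the mapping, pass NormalizationMode values when constructing the config. A sketch reusing the input_features and output_features built in the training example above:

from lerobot.configs.types import NormalizationMode
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig

cfg = VQBeTConfig(
    input_features=input_features,
    output_features=output_features,
    normalization_mapping={
        "VISUAL": NormalizationMode.IDENTITY,  # images are passed through unnormalized
        "STATE": NormalizationMode.MIN_MAX,
        "ACTION": NormalizationMode.MIN_MAX,
    },
)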

Usage Example

Loading a Pretrained Model

import torch

from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy

# Load from Hugging Face Hub
policy = VQBeTPolicy.from_pretrained("lerobot/vqbet_pusht")

# Use for inference
policy.eval()
with torch.no_grad():
    action = policy.select_action(observation)

Inference with Temperature Control

from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy

cfg = VQBeTConfig(
    input_features=input_features,
    output_features=output_features,
    bet_softmax_temperature=0.01,  # Lower temperature for more deterministic behavior
    sequentially_select=True  # Select codes sequentially
)
policy = VQBeTPolicy(cfg)

Inference Loop with Observation Queue

# Reset policy queues when environment resets
policy.reset()

# Run episode
for step in range(episode_length):
    # Policy maintains observation queue internally
    action = policy.select_action(observation)
    observation, reward, done, info = env.step(action)
    
    if done:
        policy.reset()
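With the pre/post-processors built by make_pre_post_processors (see the training example), one step of the loop looks roughly like the sketch below; exact processor usage can differ between LeRobot versions, so treat this as a sketch rather than the canonical API:

policy.reset()
for step in range(episode_length):
    batch = preprocessor(observation)     # normalize inputs and move them to the policy device
    action = policy.select_action(batch)  # internally pops from the queued action chunk
    action = postprocessor(action)        # map the action back to the environment's scale
    observation, reward, done, info = env.step(action)
    if done:
        policy.reset()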

Understanding VQ-BeT

Two-Stage Training Process

Stage 1 (Steps 0-20,000): Train VQ-VAE
  • Learn to encode action sequences into discrete codes
  • Reconstruct actions from quantized representations
  • Build a discrete latent action space
Stage 2 (Steps 20,000+): Train Policy
  • Freeze VQ-VAE weights
  • Train GPT to predict action codes from observations
  • Learn offset predictions for fine-grained control
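The stage switch is driven purely by the optimization step count. A tiny illustrative helper (not part of the LeRobot API) makes the schedule explicit:

def training_stage(step: int, n_vqvae_training_steps: int = 20000) -> str:
    """Illustrative only: which objective is being optimized at a given step."""
    if step < n_vqvae_training_steps:
        return "Stage 1: VQ-VAE reconstruction"
    return "Stage 2: code + offset prediction"

print(training_stage(10_000))  # Stage 1: VQ-VAE reconstruction
print(training_stage(30_000))  # Stage 2: code + offset prediction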

Key Concepts

  • Action Tokens: Each token represents a chunk of action_chunk_size continuous actions, encoded as two discrete codes (primary + secondary) plus a continuous offset.
  • Residual VQ: Uses multiple VQ layers where each layer encodes the residual left over from the previous layers, enabling more expressive representations (see the sketch below).
  • Autoregressive Prediction: The GPT model predicts action tokens one at a time, conditioned on previous tokens and observations.
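A toy numeric sketch of two-layer residual quantization (the codebook values below are made up for illustration; in VQ-BeT the codebooks are learned and the summed code is passed through the VQ-VAE decoder to recover an action chunk):

import torch

# Toy residual VQ with two codebooks (values are illustrative only).
primary_codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0], [-1.0, 0.5]])
secondary_codebook = torch.tensor([[0.0, 0.0], [0.1, -0.1], [-0.05, 0.2]])

latent = torch.tensor([1.08, 0.92])

# The first layer quantizes the latent; the second layer quantizes the residual.
primary_idx = torch.cdist(latent[None], primary_codebook).argmin()
residual = latent - primary_codebook[primary_idx]
secondary_idx = torch.cdist(residual[None], secondary_codebook).argmin()

# The decoded latent is the sum of the two selected code vectors.
decoded = primary_codebook[primary_idx] + secondary_codebook[secondary_idx]
print(primary_idx.item(), secondary_idx.item(), decoded)

Because the second layer refines the first, two codebooks of vqvae_n_embed=16 entries can jointly represent up to 16 x 16 = 256 distinct quantized latents.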

Advantages

  • Sample Efficiency: Discrete latent space enables more stable learning
  • Multimodal: Can represent multiple valid behaviors
  • Hierarchical: Two-level VQ captures both coarse and fine-grained actions
  • Interpretable: Discrete codes provide insight into learned behaviors

Training Tips

  • Ensure VQ-VAE is well-trained before policy training begins (monitor reconstruction loss)
  • Adjust n_vqvae_training_steps if VQ-VAE hasn’t converged by step 20,000
  • Use larger vqvae_n_embed for more complex tasks
  • Tune bet_softmax_temperature for the right exploration/exploitation balance
  • Balance the three loss weights based on task requirements

File Locations

Source files in the LeRobot repository:
  • Configuration: src/lerobot/policies/vqbet/configuration_vqbet.py
  • Model: src/lerobot/policies/vqbet/modeling_vqbet.py
  • Processor: src/lerobot/policies/vqbet/processor_vqbet.py
  • Utilities: src/lerobot/policies/vqbet/vqbet_utils.py

Citation

@article{lee2024behavior,
  title={Behavior generation with latent actions},
  author={Lee, Seungjae and Wang, Yibin and Etukuru, Haritheja and Kim, H Jin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
  journal={arXiv preprint arXiv:2403.03181},
  year={2024}
}
