Overview

The MDM class implements a Motion Diffusion Model for generating kinematic character motion sequences. It uses a transformer-based denoising network within a diffusion process to generate realistic motion conditioned on previous states, terrain heightmaps, and target objectives.

Class Definition

class MDM(MotionGenerator)
Inherits from MotionGenerator.

Constructor

MDM(cfg: dict)

cfg (dict, required): Configuration dictionary containing all model parameters and settings.

Configuration Keys

cfg.device (str, required): Device to run the model on (e.g., 'cuda', 'cpu')
cfg.char_file (str, required): Path to the character model XML file
cfg.diffusion_timesteps (int, required): Number of diffusion timesteps (typically 100-1000)
cfg.epochs (int, required): Number of training epochs
cfg.batch_size (int, required): Batch size for training
cfg.dropout (float, required): Dropout rate for the transformer
cfg.lr (float, required): Learning rate for the optimizer
cfg.d_model (int, required): Transformer model dimension
cfg.num_heads (int, required): Number of attention heads in the transformer
cfg.num_layers (int, required): Number of transformer encoder layers
cfg.seq_len (int, required): Length of generated motion sequences
cfg.num_prev_states (int, required): Number of previous states to condition on
cfg.use_heightmap_obs (bool, required): Whether to use terrain heightmap observations
cfg.use_target_obs (bool, required): Whether to use target position/direction observations
cfg.target_type (str, required): Type of target conditioning (see TargetType enum)
cfg.predict_mode (str, required): Prediction mode (see PredictMode enum)
cfg.test_mode (str, required): Generation mode for testing (see GenerationMode enum)
cfg.features (dict, required): Feature configuration including rotation type and frame components

Enumerations

PredictMode

Defines what the denoising model predicts during reverse diffusion.
class PredictMode(enum.Enum):
    PREDICT_X0 = 0      # Predict clean sample x_0
    PREDICT_NOISE = 1   # Predict noise component
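The two prediction targets are interchangeable under the standard DDPM parameterization x_t = sqrt(a_bar)*x_0 + sqrt(1-a_bar)*eps. A minimal scalar sketch of the conversion, assuming that standard parameterization (the helper names below are illustrative, not part of the MDM API):

```python
import math

# Assumed DDPM relationship: x_t = sqrt(a_bar)*x_0 + sqrt(1-a_bar)*eps,
# so an x_0 prediction and a noise prediction determine each other.
def x0_from_noise(x_t: float, eps_hat: float, a_bar: float) -> float:
    return (x_t - math.sqrt(1.0 - a_bar) * eps_hat) / math.sqrt(a_bar)

def noise_from_x0(x_t: float, x0_hat: float, a_bar: float) -> float:
    return (x_t - math.sqrt(a_bar) * x0_hat) / math.sqrt(1.0 - a_bar)

# Round trip on toy scalar values:
a_bar = 0.7
x0, eps = 1.5, -0.3
x_t = math.sqrt(a_bar) * x0 + math.sqrt(1.0 - a_bar) * eps
assert abs(x0_from_noise(x_t, eps, a_bar) - x0) < 1e-9
assert abs(noise_from_x0(x_t, x0, a_bar) - eps) < 1e-9
```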

LossType

Defines different loss components used during training.
class LossType(enum.Enum):
    SIMPLE_ROOT_POS_LOSS = 0
    SIMPLE_ROOT_ROT_LOSS = 1
    SIMPLE_JOINT_ROT_LOSS = 2
    SIMPLE_CONTACT_LOSS = 3
    VEL_ROOT_POS_LOSS = 4
    VEL_ROOT_ROT_LOSS = 5
    VEL_JOINT_ROT_LOSS = 6
    FK_BODY_POS_LOSS = 7
    FK_BODY_ROT_LOSS = 8
    FOOT_CONTACT_LOSS = 9
    TARGET_LOSS = 10
    HF_COLLISION_LOSS = 11
    SIMPLE_FLOOR_HEIGHT_LOSS = 12
    SIMPLE_BODY_POS_LOSS = 13
    BODY_POS_CONSISTENCY_LOSS = 14
    VEL_FK_BODY_POS_LOSS = 15

GenerationMode

Defines the sampling method for motion generation.
class GenerationMode(enum.Enum):
    MODE_REVERSE_DIFFUSION = 0  # Full reverse diffusion process
    MODE_DDIM = 1               # DDIM sampling (faster)
    MODE_ECM = 2                # ECM sampling
    NONE = 3                    # No sampling (training only)

TargetType

Defines the type of target conditioning.
class TargetType(enum.Enum):
    XY_POS = 0                      # XY position only
    XY_POS_AND_HEADING = 1          # XY position + heading angle
    XY_POS_AND_DELTA_FLOOR_Z = 2    # XY position + floor height delta
    XY_DIR = 3                      # XY direction vector

Key Methods

gen_sequence

Generates a motion sequence given conditions.
gen_sequence(
    conds: dict,
    ddim_stride: int = None,
    mode: GenerationMode = GenerationMode.MODE_DDIM
) -> tuple[torch.Tensor, dict]
conds (dict, required): Dictionary containing conditioning information:
  • MDMKeyType.PREV_STATE_KEY: Previous motion states (dict of MDMFrameType components)
  • MDMKeyType.OBS_KEY: Heightmap observations (if enabled)
  • MDMKeyType.TARGET_KEY: Target position/direction (if enabled)
ddim_stride (int, optional): Stride for DDIM sampling (smaller = higher quality, slower)
mode (GenerationMode, optional): Generation mode to use
Returns: Tuple of (motion_data, info) where motion_data is a dict of motion components.

train

Trains the diffusion model on motion data.
train(
    motion_sampler: MotionSampler,
    checkpoint_dir: Path = None,
    test_only: bool = False,
    stats_filepath: str = None
)
motion_sampler (MotionSampler, required): Motion sampler for loading training data
checkpoint_dir (Path, optional): Directory to save model checkpoints
test_only (bool, optional): If True, only evaluate without updating weights
stats_filepath (str, optional): Path to save/load feature normalization statistics

compute_stats

Computes mean and standard deviation statistics for motion features.
compute_stats(
    sampler: MotionSampler,
    stats_filepath: str = None
)
sampler (MotionSampler, required): Motion sampler containing the dataset
stats_filepath (str, optional): Path to save/load statistics file
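A minimal sketch of per-dimension mean/std computation for standardization. This is illustrative only: the actual method iterates the MotionSampler dataset, and the zero-variance guard value is an assumption:

```python
# Hypothetical stand-in for feature statistics computation: each sample is a
# flat feature list; statistics are computed independently per dimension.
def compute_stats(samples):
    dim, n = len(samples[0]), len(samples)
    mean = [sum(s[d] for s in samples) / n for d in range(dim)]
    std = [(sum((s[d] - mean[d]) ** 2 for s in samples) / n) ** 0.5
           for d in range(dim)]
    std = [s if s > 1e-8 else 1.0 for s in std]  # guard zero-variance dims
    return mean, std

mean, std = compute_stats([[0.0, 2.0], [2.0, 2.0]])
assert mean == [1.0, 2.0]
assert std == [1.0, 1.0]  # second dim is constant, guarded to 1.0
```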

forward_diffusion

Applies forward diffusion process to add noise to clean samples.
forward_diffusion(
    x_0: torch.Tensor,
    t: torch.Tensor,
    noise: torch.Tensor = None
) -> torch.Tensor
x_0 (torch.Tensor, required): Clean motion samples of shape [batch_size, seq_len, motion_dim]
t (torch.Tensor, required): Timestep indices of shape [batch_size, 1, 1]
noise (torch.Tensor, optional): Noise tensor. If None, random noise is sampled.
Returns: Noised samples at timestep t.
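As a minimal scalar sketch of the forward process q(x_t | x_0), assuming a linear beta schedule (the actual MDM schedule may differ):

```python
import math
import random

# Cumulative products of (1 - beta_t) under an assumed linear beta schedule.
def make_alpha_bars(T: int, beta_start: float = 1e-4, beta_end: float = 0.02):
    alpha_bars, prod = [], 1.0
    for t in range(T):
        beta = beta_start + (beta_end - beta_start) * t / (T - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars

def forward_diffusion(x0: float, t: int, alpha_bars, noise=None) -> float:
    if noise is None:
        noise = random.gauss(0.0, 1.0)  # sample noise when none is given
    a_bar = alpha_bars[t]
    return math.sqrt(a_bar) * x0 + math.sqrt(1.0 - a_bar) * noise

abars = make_alpha_bars(100)
# With zero noise, x_t is just sqrt(alpha_bar_t) * x_0, shrinking toward 0
# as t grows (alpha_bar is monotonically decreasing).
x_t = forward_diffusion(1.0, 99, abars, noise=0.0)
assert 0.0 < x_t < 1.0
assert abars[0] > abars[-1]
```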

reverse_diffusion

Performs reverse diffusion to generate samples from noise.
reverse_diffusion(
    conds: dict,
    noise: torch.Tensor = None,
    start_timestep: int = None,
    end_timestep: int = None,
    keep_all_samples: bool = False,
    stride: int = 1
) -> torch.Tensor | list[torch.Tensor]
conds (dict, required): Conditioning information dictionary
noise (torch.Tensor, optional): Initial noise tensor. If None, random noise is sampled.
start_timestep (int, optional): Starting timestep (defaults to diffusion_timesteps)
end_timestep (int, optional): Ending timestep (defaults to 0)
keep_all_samples (bool, optional): If True, returns all intermediate samples
stride (int, optional): Timestep stride for sampling
Returns: Generated motion samples (or list if keep_all_samples=True).
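A scalar skeleton of ancestral reverse diffusion with an x_0-predicting model (PREDICT_X0). The `denoise_fn` argument stands in for the conditioned transformer; this is a sketch of the generic DDPM posterior update, not the exact MDM implementation:

```python
import math
import random

def reverse_diffusion(denoise_fn, betas, noise=None):
    T = len(betas)
    alphas = [1.0 - b for b in betas]
    a_bars, prod = [], 1.0
    for a in alphas:
        prod *= a
        a_bars.append(prod)
    x = noise if noise is not None else random.gauss(0.0, 1.0)
    for t in reversed(range(T)):
        x0_hat = denoise_fn(x, t)  # model's clean-sample prediction
        a_bar_prev = a_bars[t - 1] if t > 0 else 1.0
        # Mean of the DDPM posterior q(x_{t-1} | x_t, x_0).
        coef0 = math.sqrt(a_bar_prev) * betas[t] / (1.0 - a_bars[t])
        coeft = math.sqrt(alphas[t]) * (1.0 - a_bar_prev) / (1.0 - a_bars[t])
        mean = coef0 * x0_hat + coeft * x
        var = betas[t] * (1.0 - a_bar_prev) / (1.0 - a_bars[t])
        # Add posterior noise on every step except the final one.
        x = mean + (math.sqrt(var) * random.gauss(0.0, 1.0) if t > 0 else 0.0)
    return x

# With an idealized denoiser that always returns the true x_0 = 0.5,
# sampling lands exactly on that value at the final step.
betas = [1e-4 + (0.02 - 1e-4) * t / 99 for t in range(100)]
x = reverse_diffusion(lambda x, t: 0.5, betas, noise=0.0)
assert abs(x - 0.5) < 1e-9
```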

ddim_inference

Performs DDIM inference for faster sampling.
ddim_inference(
    conds: dict,
    noise: torch.Tensor = None,
    stride: int = 2,
    keep_all_samples: bool = False
) -> torch.Tensor | list[torch.Tensor]
conds (dict, required): Conditioning information dictionary
noise (torch.Tensor, optional): Initial noise tensor. If None, random noise is sampled.
stride (int, optional): Stride between timesteps (larger = faster but lower quality)
keep_all_samples (bool, optional): If True, returns all intermediate samples
Returns: Generated motion samples (or list if keep_all_samples=True).
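A scalar sketch of deterministic DDIM (eta = 0) with a timestep stride, again assuming an x_0-predicting denoiser. It shows why the stride trades speed for quality: the sampler only visits a strided subsequence of timesteps:

```python
import math

def ddim_inference(denoise_fn, a_bars, stride=2, x_T=0.0):
    T = len(a_bars)
    steps = list(range(T - 1, -1, -stride))  # strided timestep subsequence
    x = x_T
    for i, t in enumerate(steps):
        x0_hat = denoise_fn(x, t)
        # Implied noise under x_t = sqrt(a_bar)*x_0 + sqrt(1-a_bar)*eps.
        eps_hat = (x - math.sqrt(a_bars[t]) * x0_hat) / math.sqrt(1.0 - a_bars[t])
        a_prev = a_bars[steps[i + 1]] if i + 1 < len(steps) else 1.0
        # Deterministic DDIM jump to the previous (strided) timestep.
        x = math.sqrt(a_prev) * x0_hat + math.sqrt(1.0 - a_prev) * eps_hat
    return x

a_bars = [0.999 ** (t + 1) for t in range(100)]  # toy cumulative-alpha schedule
# An idealized denoiser returning the true x_0 = 0.5 recovers it exactly,
# using only 10 of the 100 timesteps.
x = ddim_inference(lambda x, t: 0.5, a_bars, stride=10)
assert abs(x - 0.5) < 1e-9
```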

assemble_mdm_features

Assembles MDM features from motion component dictionary.
assemble_mdm_features(
    ml_component_dict: dict,
    standardize: bool = True
) -> torch.Tensor
ml_component_dict (dict, required): Dictionary mapping MDMFrameType to component tensors
standardize (bool, optional): Whether to apply standardization using computed statistics
Returns: Assembled feature tensor of shape [batch_size, seq_len, motion_dim].

extract_motion_features

Extracts motion components from MDM feature tensor.
extract_motion_features(
    mdm_features: torch.Tensor,
    unstandardize: bool = True
) -> dict
mdm_features (torch.Tensor, required): MDM feature tensor of shape [batch_size, seq_len, motion_dim]
unstandardize (bool, optional): Whether to unstandardize features using computed statistics
Returns: Dictionary mapping MDMFrameType to component tensors.
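The assemble/extract pair forms a round trip: concatenate components in a fixed layout and standardize, then invert both steps. A minimal sketch on flat per-frame lists; the component names, layout, and statistics below are illustrative, not the actual MDMFrameType layout:

```python
# Hypothetical per-dimension statistics and (component, width) layout.
MEAN = [0.0, 1.0, -2.0]
STD = [1.0, 2.0, 0.5]
LAYOUT = [("ROOT_POS", 2), ("CONTACTS", 1)]

def assemble(components, standardize=True):
    # Concatenate components in layout order, then standardize per dimension.
    flat = [v for name, width in LAYOUT for v in components[name]]
    if standardize:
        flat = [(x - m) / s for x, m, s in zip(flat, MEAN, STD)]
    return flat

def extract(flat, unstandardize=True):
    # Invert standardization, then split back into named components.
    if unstandardize:
        flat = [x * s + m for x, m, s in zip(flat, MEAN, STD)]
    out, i = {}, 0
    for name, width in LAYOUT:
        out[name] = flat[i:i + width]
        i += width
    return out

comps = {"ROOT_POS": [0.3, -0.7], "CONTACTS": [1.0]}
back = extract(assemble(comps))
assert all(abs(a - b) < 1e-9
           for k in comps for a, b in zip(back[k], comps[k]))
```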

Usage Example

import torch
from parc.motion_generator.mdm import MDM, GenerationMode, TargetType
from parc.motion_generator.motion_sampler import MotionSampler
from parc.motion_generator.diffusion_util import MDMKeyType, MDMFrameType

# Configuration
cfg = {
    'device': 'cuda',
    'char_file': 'path/to/character.xml',
    'diffusion_timesteps': 100,
    'epochs': 1000,
    'batch_size': 32,
    'dropout': 0.1,
    'lr': 1e-4,
    'd_model': 512,
    'num_heads': 8,
    'num_layers': 8,
    'seq_len': 64,
    'num_prev_states': 2,
    'use_heightmap_obs': True,
    'use_target_obs': True,
    'target_type': 'XY_DIR',
    'predict_mode': 'PREDICT_X0',
    'test_mode': 'MODE_DDIM',
    'test_ddim_stride': 10,
    'features': {
        'rot_type': 'ROT_6D',
        'frame_components': ['ROOT_POS', 'ROOT_ROT', 'JOINT_ROT', 'CONTACTS']
    },
    # ... additional configuration
}

# Initialize model
mdm = MDM(cfg)

# Train the model
motion_sampler = MotionSampler(sampler_cfg)
mdm.train(motion_sampler, checkpoint_dir='checkpoints/')

# Generate motion
conds = {
    MDMKeyType.PREV_STATE_KEY: prev_motion_dict,  # Previous motion states
    MDMKeyType.OBS_KEY: heightmap_tensor,          # Terrain heightmap
    MDMKeyType.TARGET_KEY: target_direction,       # Target direction
    MDMKeyType.PREV_STATE_FLAG_KEY: torch.ones(batch_size, dtype=torch.bool),
    MDMKeyType.OBS_FLAG_KEY: torch.ones(batch_size, dtype=torch.bool),
    MDMKeyType.TARGET_FLAG_KEY: torch.ones(batch_size, dtype=torch.bool),
}

generated_motion, info = mdm.gen_sequence(
    conds,
    ddim_stride=10,
    mode=GenerationMode.MODE_DDIM
)

Helper Functions

pseudo_huber_loss_fn

pseudo_huber_loss_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
Computes pseudo-Huber loss between two tensors. More robust to outliers than L2 loss.
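A scalar sketch of the pseudo-Huber idea, quadratic near zero and linear in the tails; the smoothing constant `c` is an assumed default, not necessarily what the library uses:

```python
import math

def pseudo_huber(x: float, y: float, c: float = 1.0) -> float:
    # sqrt(d^2 + c^2) - c: ~d^2/(2c) for small residuals, ~|d| for large ones.
    d = x - y
    return math.sqrt(d * d + c * c) - c

def squared_l2(x: float, y: float) -> float:
    return (x - y) ** 2

# Large residuals are penalized roughly linearly instead of quadratically,
# which is what makes the loss more robust to outliers.
assert pseudo_huber(10.0, 0.0) < squared_l2(10.0, 0.0)
assert abs(pseudo_huber(0.0, 0.0)) < 1e-12
```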

squared_l2_loss_fn

squared_l2_loss_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor
Computes squared L2 loss between two tensors.

get_dir_from_motion

get_dir_from_motion(
    motions: torch.Tensor,
    eps: float = 0.05
) -> torch.Tensor
Extracts direction vector from motion trajectory.
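A minimal sketch of the idea, assuming the direction is the normalized start-to-end XY displacement of the root trajectory with an `eps` threshold for the near-stationary case (the actual implementation may differ):

```python
import math

def dir_from_trajectory(xy_positions, eps=0.05):
    # Net planar displacement from first to last frame.
    dx = xy_positions[-1][0] - xy_positions[0][0]
    dy = xy_positions[-1][1] - xy_positions[0][1]
    norm = math.hypot(dx, dy)
    if norm < eps:  # standing still: no reliable direction
        return (0.0, 0.0)
    return (dx / norm, dy / norm)

assert dir_from_trajectory([(0.0, 0.0), (3.0, 4.0)]) == (0.6, 0.8)
assert dir_from_trajectory([(0.0, 0.0), (0.01, 0.0)]) == (0.0, 0.0)
```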

get_dir_from_canonicalized_pos_and_rot

get_dir_from_canonicalized_pos_and_rot(
    pos: torch.Tensor,
    rot: torch.Tensor,
    pos_eps: float = 0.1,
    heading_eps: float = 0.25
) -> torch.Tensor
Computes target direction from canonicalized position and rotation, handling standing-still cases.
