Skip to main content

Distributed Training

ONNX Runtime Training seamlessly integrates with popular distributed training frameworks to scale training across multiple GPUs and nodes. This guide covers setup and best practices for distributed training with ORTModule.

Supported Frameworks

ORTModule works with:
  • PyTorch DDP (DistributedDataParallel): Native PyTorch multi-GPU training
  • DeepSpeed: Memory-efficient training with ZeRO optimizer
  • DeepSpeed Pipeline Parallelism: Model parallelism for very large models
  • PyTorch FSDP: Fully Sharded Data Parallel
  • Horovod: Multi-framework distributed training

PyTorch DistributedDataParallel (DDP)

Basic Setup

Wrap your model with ORTModule before DDP:
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from onnxruntime.training.ortmodule import ORTModule

def setup(rank, world_size):
    """Join the default process group for this worker.

    Uses the NCCL backend (GPU collectives); the rendezvous endpoint is
    read from the MASTER_ADDR / MASTER_PORT environment variables.
    """
    group_args = {
        "backend": "nccl",
        "init_method": "env://",
        "world_size": world_size,
        "rank": rank,
    }
    dist.init_process_group(**group_args)

def train(rank, world_size):
    """Per-process training entry point: wrap with ORTModule, then DDP.

    `rank` doubles as the CUDA device index (single-node assumption).
    NOTE(review): `num_epochs`, `dataloader` and `criterion` are assumed
    to be defined at module level elsewhere — confirm before running.
    """
    setup(rank, world_size)

    # Build model on this process's GPU
    model = build_model().to(rank)

    # Wrap with ORTModule first, so ORT captures the raw model graph
    model = ORTModule(model)

    # Then wrap with DDP for gradient synchronization
    model = DDP(model, device_ids=[rank])

    # Standard training loop
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        for batch in dataloader:
            optimizer.zero_grad()

            inputs = batch['input'].to(rank)
            labels = batch['label'].to(rank)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

    # Tear down the process group so peers do not hang on exit
    dist.destroy_process_group()

if __name__ == "__main__":
    # One worker per visible GPU; spawn() supplies each worker's rank as
    # the first argument, so train() receives (rank, world_size).
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(
        train,
        args=(world_size,),
        nprocs=world_size,
        join=True
    )

Launch Script

# Single node, 8 GPUs
# NOTE: torch.distributed.launch is deprecated (PyTorch >= 1.9);
# torchrun is the recommended launcher. It sets RANK/LOCAL_RANK/
# WORLD_SIZE in the environment automatically, so --use_env is no
# longer needed.
torchrun \
    --nproc_per_node=8 \
    train.py

# Multi-node training
torchrun \
    --nproc_per_node=8 \
    --nnodes=4 \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    train.py

Complete DDP Example

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from onnxruntime.training.ortmodule import ORTModule

class Trainer:
    """Distributed trainer: ORTModule-accelerated model wrapped in DDP.

    NOTE(review): `rank` serves as both the CUDA device index and the
    DistributedSampler shard index, which is only correct when local and
    global ranks coincide (single node) — confirm for multi-node use.
    """

    def __init__(self, model, train_dataset, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        
        # Setup distributed (env:// rendezvous — reads MASTER_ADDR,
        # MASTER_PORT, RANK and WORLD_SIZE set by the launcher)
        self.setup_distributed()
        
        # Wrap model — ORTModule must wrap the model BEFORE DDP
        self.model = model.to(rank)
        self.model = ORTModule(self.model)
        self.model = DDP(self.model, device_ids=[rank])
        
        # Setup data loader with distributed sampler so each rank
        # iterates a disjoint shard of the dataset
        self.sampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True
        )
        
        self.dataloader = DataLoader(
            train_dataset,
            batch_size=32,  # per-rank batch size (hard-coded in this example)
            sampler=self.sampler,
            num_workers=4,
            pin_memory=True
        )
        
        # Setup optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=1e-4
        )
        
        self.criterion = torch.nn.CrossEntropyLoss()
    
    def setup_distributed(self):
        # Rank and world size are taken from the environment by env://
        dist.init_process_group(
            backend="nccl",
            init_method="env://"
        )
    
    def train_epoch(self, epoch):
        """Run one epoch; logs on rank 0 and all-reduces the mean loss."""
        self.model.train()
        self.sampler.set_epoch(epoch)  # Shuffle differently each epoch
        
        total_loss = 0.0
        for batch_idx, (data, target) in enumerate(self.dataloader):
            data = data.to(self.rank)
            target = target.to(self.rank)
            
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            
            total_loss += loss.item()
            
            if batch_idx % 10 == 0 and self.rank == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        
        # Reduce loss across all processes to report a global average
        total_loss_tensor = torch.tensor(total_loss).to(self.rank)
        dist.all_reduce(total_loss_tensor, op=dist.ReduceOp.SUM)
        avg_loss = total_loss_tensor.item() / (len(self.dataloader) * self.world_size)
        
        if self.rank == 0:
            print(f"Epoch {epoch}, Average Loss: {avg_loss:.4f}")
    
    def cleanup(self):
        # Release the process group; call once, after training finishes
        dist.destroy_process_group()

def main():
    """Entry point for one worker process started by the launcher.

    Reads rank/world size from environment variables set by torchrun.
    NOTE(review): LOCAL_RANK is the per-node rank; for multi-node jobs the
    DistributedSampler inside Trainer should receive the *global* RANK, or
    nodes will iterate the same data shard — confirm single-node use.
    """
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    model = build_model()
    dataset = load_dataset()

    trainer = Trainer(model, dataset, rank, world_size)

    try:
        for epoch in range(10):
            trainer.train_epoch(epoch)
    finally:
        # Always destroy the process group, even if an epoch raised,
        # so NCCL resources are released and peer ranks do not hang.
        trainer.cleanup()

if __name__ == "__main__":
    # Each launched process runs main() independently.
    main()

DeepSpeed Integration

DeepSpeed provides memory-efficient training through ZeRO optimizer stages.

Basic DeepSpeed Setup

import deepspeed
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.optim import FusedAdam
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer

def create_model_and_optimizer():
    """Build the model, wrap it with ORTModule, and pair it with FusedAdam.

    Returns:
        (model, optimizer) tuple ready to pass to deepspeed.initialize().
    """
    # ORTModule must wrap the raw model before any distributed wrapper.
    ort_model = ORTModule(build_model())

    # FusedAdam fuses the optimizer update kernels for better performance.
    fused_optimizer = FusedAdam(ort_model.parameters(), lr=1e-4)

    return ort_model, fused_optimizer

def train():
    """DeepSpeed training loop for an ORTModule-wrapped model.

    NOTE(review): `num_epochs`, `dataloader` and `criterion` are assumed
    to be defined elsewhere in the script — confirm before running.
    """
    model, optimizer = create_model_and_optimizer()

    # DeepSpeed configuration: fp16 with dynamic loss scaling
    # (loss_scale 0 = dynamic) and ZeRO stage 2 (optimizer state +
    # gradient partitioning) with optimizer state offloaded to CPU.
    ds_config = {
        "train_batch_size": 32,
        "gradient_accumulation_steps": 1,
        "fp16": {
            "enabled": True,
            "loss_scale": 0,
            "initial_scale_power": 16
        },
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {
                "device": "cpu"
            }
        }
    }

    # Initialize DeepSpeed; the returned engine owns backward/step
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        config=ds_config
    )

    # Optionally wrap with ORT's FP16_Optimizer
    # (presumably accelerates mixed-precision gradient handling — verify
    # against the onnxruntime.training.optim documentation)
    optimizer = FP16_Optimizer(optimizer)

    # Training loop: use engine.backward()/engine.step(), not the raw
    # optimizer, so loss scaling and accumulation stay consistent
    for epoch in range(num_epochs):
        for batch in dataloader:
            inputs = batch['input'].to(model_engine.local_rank)
            labels = batch['label'].to(model_engine.local_rank)

            outputs = model_engine(inputs)
            loss = criterion(outputs, labels)

            model_engine.backward(loss)
            model_engine.step()

DeepSpeed Configuration File

{
  "train_batch_size": 32,
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 8,
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 1e-4,
      "weight_decay": 0.01,
      "betas": [0.9, 0.999],
      "eps": 1e-8
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": 1e-4,
      "warmup_num_steps": 1000
    }
  },
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "initial_scale_power": 16,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "overlap_comm": true,
    "contiguous_gradients": true
  },
  "wall_clock_breakdown": false
}

Launch with DeepSpeed

# The deepspeed launcher spawns one process per GPU and sets the
# distributed environment variables automatically.
deepspeed --num_gpus=8 train.py \
    --deepspeed \
    --deepspeed_config ds_config.json

DeepSpeed Pipeline Parallelism

For models too large for a single GPU, use pipeline parallelism:
from onnxruntime.training.ortmodule import DebugOptions, LogLevel
from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule
import deepspeed

def create_pipeline_model():
    """Build an ORTPipelineModule that splits a layer list into 4 stages.

    Returns the pipeline module ready for deepspeed.initialize().
    """
    # Fix: `nn` was referenced but never imported in this example;
    # import it here so the snippet is self-contained.
    from torch import nn

    # Define model layers as a flat sequence; consecutive layers are
    # grouped into pipeline stages.
    layers = [
        nn.Linear(1024, 2048),
        nn.ReLU(),
        nn.Linear(2048, 2048),
        nn.ReLU(),
        nn.Linear(2048, 1024),
        # ... more layers
    ]

    # Debug options (optional): dump the exported ONNX graph and log at INFO
    debug_options = DebugOptions(
        save_onnx=True,
        log_level=LogLevel.INFO,
        onnx_prefix="pipeline_model"
    )

    # Create pipeline module
    pipeline_model = ORTPipelineModule(
        layers,
        num_stages=4,  # Partition across 4 GPUs
        partition_method="parameters",  # balance stages by parameter count
        base_seed=1234,
        debug_options=debug_options
    )

    return pipeline_model

def train_pipeline():
    """Run the pipeline-parallel training loop under DeepSpeed."""
    pipeline_model = create_pipeline_model()

    # DeepSpeed partitions the stages across ranks and builds the
    # optimizer from the JSON config file.
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=pipeline_model,
        model_parameters=list(pipeline_model.parameters()),
        config="pipeline_config.json",
    )

    # The pipeline engine consumes whole batches and returns the loss.
    for batch in dataloader:
        loss = model_engine(batch)
        model_engine.backward(loss)
        model_engine.step()

Data Loading Best Practices

Use DistributedSampler

Ensure each process gets different data:
from torch.utils.data import DataLoader, DistributedSampler

sampler = DistributedSampler(
    dataset,
    num_replicas=world_size,  # total number of participating processes
    rank=rank,                # this process's shard index
    shuffle=True,
    drop_last=True            # keep per-rank batch counts identical
)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=sampler,            # a sampler replaces shuffle= on the loader
    num_workers=4,
    pin_memory=True,            # faster host-to-GPU copies
    persistent_workers=True     # avoid re-forking workers every epoch
)

# Update sampler epoch for proper shuffling: without set_epoch() every
# epoch replays the same shuffle order on every rank.
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
    for batch in dataloader:
        ...  # training code (placeholder so the snippet parses as Python)

Load Balancing for Variable Length Sequences

For NLP and speech tasks with variable length inputs:
from onnxruntime.training.utils.data import (
    LoadBalancingDistributedSampler,
    LoadBalancingDistributedBatchSampler
)

# Define complexity function (e.g., sequence length); the sampler uses it
# to spread expensive samples evenly across ranks.
def complexity_fn(sample):
    return len(sample['input_ids'])

# Define batch function
def batch_fn(samples):
    # Custom batching logic (collate_fn assumed defined elsewhere)
    return collate_fn(samples)

# Create load-balanced sampler
sampler = LoadBalancingDistributedSampler(
    dataset,
    complexity_fn=complexity_fn
)

batch_sampler = LoadBalancingDistributedBatchSampler(
    sampler,
    batch_fn=batch_fn
)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=batch_sampler  # batch_sampler excludes batch_size/sampler
)

for epoch in range(num_epochs):
    batch_sampler.set_epoch(epoch)  # reshuffle deterministically per epoch
    for batch in loader:
        ...  # training code (placeholder so the snippet parses as Python)
This helps avoid the “straggler problem” where some GPUs finish faster than others.

Environment Variables for Distributed Training

Essential Variables

# PyTorch DDP rendezvous settings (torchrun sets these automatically)
export MASTER_ADDR="localhost"   # host of the rank-0 process
export MASTER_PORT="29500"       # free TCP port on the rank-0 host
export WORLD_SIZE=8              # total number of processes in the job
export RANK=0                    # this process's global rank
export LOCAL_RANK=0              # this process's rank within its node

# NCCL tuning
export NCCL_DEBUG=INFO           # verbose NCCL logging
export NCCL_IB_DISABLE=0         # 0 = leave InfiniBand enabled
export NCCL_SOCKET_IFNAME=eth0   # network interface NCCL should use

ORTModule Distributed Settings

# Disable fallback for consistent performance: with fallback disabled,
# export failures raise instead of silently running in plain PyTorch
export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"

# Enable memory optimizations
export ORTMODULE_MEMORY_OPT_LEVEL=1
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1

# Cache exported models (useful for multi-node; presumably the path must
# be readable by every rank — verify it is on shared storage)
export ORTMODULE_CACHE_DIR="/shared/cache"

Checkpoint Saving and Loading

Save Checkpoints (Rank 0 only)

import torch.distributed as dist

def save_checkpoint(model, optimizer, epoch, path):
    """Persist model/optimizer state from rank 0 only.

    Under DDP every rank holds identical weights, so writing once avoids
    redundant I/O and concurrent writes to the same file.
    """
    if dist.get_rank() != 0:
        return
    state = {
        'epoch': epoch,
        # DDP wraps the real model in .module; save the unwrapped weights
        # so the checkpoint also loads into non-DDP models.
        'model_state_dict': model.module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(state, path)
    print(f"Checkpoint saved to {path}")

def load_checkpoint(model, optimizer, path):
    """Restore model/optimizer state on every rank; returns the saved epoch.

    Fix: the original mapped tensors to cuda:{dist.get_rank()}, but
    get_rank() is the *global* rank, which exceeds the per-node device
    count in multi-node jobs. Prefer LOCAL_RANK when the launcher
    provides it, falling back to the global rank for single-node runs.
    """
    import os

    local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
    checkpoint = torch.load(path, map_location=f'cuda:{local_rank}')
    model.module.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

Testing and Debugging

Test Distributed Setup

import torch
import torch.distributed as dist

def test_distributed():
    """Smoke-test the distributed setup with a single all-reduce.

    Each rank contributes its rank id; after a SUM all-reduce every rank
    must hold sum(range(world_size)).
    """
    # Initialize process group (rank/world size read from the environment)
    dist.init_process_group(backend="nccl")
    
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    
    print(f"Rank {rank}/{world_size} initialized")
    
    # Test all-reduce
    # NOTE(review): .cuda() with no index targets the default device, so
    # every process on a node uses GPU 0 as written — consider
    # torch.cuda.set_device(local_rank) first; confirm intent.
    tensor = torch.tensor([rank]).cuda()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    
    expected = sum(range(world_size))
    assert tensor.item() == expected, f"All-reduce failed: {tensor.item()} != {expected}"
    
    print(f"Rank {rank} passed all-reduce test")
    
    dist.destroy_process_group()

if __name__ == "__main__":
    test_distributed()

Enable Detailed Logging

# Verbose diagnostics from ORTModule, NCCL and torch.distributed
export ORTMODULE_LOG_LEVEL=DEVINFO
export NCCL_DEBUG=INFO
export TORCH_DISTRIBUTED_DEBUG=DETAIL

Performance Tips

  1. Wrap Order: Always wrap with ORTModule before DDP/DeepSpeed
  2. Batch Size: Use largest batch size that fits in memory
  3. Gradient Accumulation: Simulate larger batches with accumulation
  4. Mixed Precision: Enable FP16 training for faster computation
  5. Communication Backend: Use NCCL for GPU training, Gloo for CPU
  6. Pin Memory: Enable pin_memory=True in DataLoader
  7. Persistent Workers: Set persistent_workers=True to avoid respawning
  8. NCCL Tuning: Optimize NCCL settings for your network topology

Common Issues

Hanging on Initialization

# Check network connectivity
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL   # log every NCCL subsystem, not just init

# Use a different port (e.g. when 29500 is held by another job)
export MASTER_PORT=29501

Out of Memory

# Enable memory optimizations (level 2 is more aggressive than the
# level 1 suggested earlier in this guide)
export ORTMODULE_MEMORY_OPT_LEVEL=2
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1

# Reduce batch size or enable gradient accumulation

Gradient Synchronization Issues

# Find unused parameters (adds per-iteration overhead; enable only when
# some parameters legitimately receive no gradient in a forward pass)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)

Next Steps

ORTModule

Learn more about ORTModule features

Training Overview

Explore other training options