ORTModule for PyTorch Integration

ORTModule is the easiest way to accelerate your PyTorch training. It is a drop-in wrapper for your torch.nn.Module that runs the forward and backward passes through ONNX Runtime’s optimized training backend, while the rest of your script stays standard PyTorch.

Quick Start

Add just two lines to your existing PyTorch training code: the import and the wrap.
import torch
from onnxruntime.training.ortmodule import ORTModule

model = build_model()     # your existing torch.nn.Module
model = ORTModule(model)  # wrap your model

# The rest of your training code (criterion, dataloader, ...) remains the same
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for data, target in dataloader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

Complete Training Example

Here’s a complete MNIST training example showing ORTModule in action:
import torch
from torchvision import datasets, transforms
from onnxruntime.training.ortmodule import ORTModule

class NeuralNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input1):
        out = self.fc1(input1)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# Create model and wrap with ORTModule
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
model = ORTModule(model)

# Setup optimizer and loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

# Load data
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=32, shuffle=True)

# Training loop
model.train()
for epoch in range(5):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        data = data.reshape(data.shape[0], -1)
        
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

HuggingFace Transformers Example

ORTModule works seamlessly with HuggingFace transformers:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from onnxruntime.training.ortmodule import ORTModule

# Load pre-trained model
model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2
)

# Wrap with ORTModule for acceleration
model = ORTModule(model)
model.to('cuda')

# Standard training setup (use torch.optim.AdamW; transformers.AdamW is
# deprecated in recent transformers releases)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Training loop
model.train()
for epoch in range(3):
    for batch in train_dataloader:  # your tokenized DataLoader (construction not shown)
        optimizer.zero_grad()
        
        input_ids = batch['input_ids'].to('cuda')
        attention_mask = batch['attention_mask'].to('cuda')
        labels = batch['labels'].to('cuda')
        
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        
        loss = outputs.loss
        loss.backward()
        optimizer.step()

Debug Options

For development and debugging, ORTModule provides detailed logging and graph export:
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel

model = build_model()

# Enable debug options
debug_options = DebugOptions(
    save_onnx=True,              # Export ONNX graphs
    log_level=LogLevel.VERBOSE,  # Detailed logging
    onnx_prefix="model_name"     # Prefix for exported files
)

model = ORTModule(model, debug_options)

Log Levels

  • WARNING (default): User-facing warnings and errors
  • INFO: Experimental feature stats, more error details
  • DEVINFO: Recommended for debugging, includes all rank logs
  • VERBOSE: Maximum verbosity, backend and exporter logs
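
The same levels can also be selected without code changes through the ORTMODULE_LOG_LEVEL environment variable (listed under Environment Variables below); for example, for a debugging session:

```shell
# Select DEVINFO logging; set before the training process starts
export ORTMODULE_LOG_LEVEL="DEVINFO"
```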

Environment Variables

ORTModule behavior can be customized via environment variables:

Fallback Policy

# Disable fallback to PyTorch (useful for benchmarking)
export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"

ONNX Opset Version

# Pin to specific ONNX opset version
export ORTMODULE_ONNX_OPSET_VERSION=14

Save ONNX Models

# Export ONNX models for inspection
export ORTMODULE_SAVE_ONNX_PATH="/path/to/output"
export ORTMODULE_LOG_LEVEL="INFO"

Memory Optimization

# Enable gradient checkpointing (level 0-2)
export ORTMODULE_MEMORY_OPT_LEVEL=1

# Enable memory-efficient gradient management
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1

Cache Exported Models

# Cache exported models to reduce startup time
export ORTMODULE_CACHE_DIR="/path/to/cache"

Computation Optimizations

# Enable/disable compute optimizer
export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1

# Enable/disable embedding sparse optimizer
export ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER=1

# Enable/disable label sparse optimizer
export ORTMODULE_ENABLE_LABEL_SPARSE_OPTIMIZER=1

Attention Optimizations

# Enable Flash Attention (requires Triton)
export ORTMODULE_USE_FLASH_ATTENTION=1

# Enable efficient attention ATen kernel
export ORTMODULE_USE_EFFICIENT_ATTENTION=1

# Enable scaled dot product attention fallback
export ORTMODULE_ATEN_SDPA_FALLBACK=1

Triton Integration

# Enable OpenAI Triton for kernel execution
export ORTMODULE_USE_TRITON=1

# Specify custom Triton config
export ORTMODULE_TRITON_CONFIG_FILE="triton_config.json"

# Enable kernel tuning
export ORTMODULE_ENABLE_TUNING=1
export ORTMODULE_MAX_TUNING_DURATION_MS=10000
export ORTMODULE_TUNING_RESULTS_PATH="/path/to/results"

# Enable Triton debug mode
export ORTMODULE_TRITON_DEBUG=1

Custom Autograd Functions

# Enable/disable custom autograd functions
export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=1

# Allow gradient checkpointing
export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1

Debugging Options

# Print input data sparsity inspection
export ORTMODULE_PRINT_INPUT_DENSITY=1

# Print memory statistics
export ORTMODULE_PRINT_MEMORY_STATS=1

# Control deep copy before export
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=1

Performance Optimizations

FusedAdam Optimizer

Replace PyTorch’s AdamW with FusedAdam for faster parameter updates:
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.optim import FusedAdam

model = ORTModule(build_model())
optimizer = FusedAdam(model.parameters(), lr=1e-4)

Combined with DeepSpeed

Combine ORTModule with DeepSpeed for maximum performance:
import deepspeed
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.optim import FusedAdam
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer

# Wrap model with ORTModule first
model = ORTModule(build_model())

# Use FusedAdam
optimizer = FusedAdam(model.parameters(), lr=1e-4)

# Initialize DeepSpeed (args, lr_scheduler, and mpu come from your own
# training script and DeepSpeed configuration)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    args=args,
    lr_scheduler=lr_scheduler,
    mpu=mpu,
    dist_init_required=False
)

# Wrap with FP16_Optimizer
optimizer = FP16_Optimizer(optimizer)

Memory Optimization

Reduce memory usage to train larger models:
import os

# Enable gradient checkpointing (level 1 or 2)
os.environ['ORTMODULE_MEMORY_OPT_LEVEL'] = '1'

# Enable memory-efficient gradient management
os.environ['ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT'] = '1'

from onnxruntime.training.ortmodule import ORTModule
model = ORTModule(build_model())

Memory Optimization Levels

  • Level 0 (default): No recomputation
  • Level 1: Recompute detected subgraphs (equivalent to PyTorch gradient checkpointing)
  • Level 2: Aggressive recomputation including compromised subgraphs
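
As a sketch, to try the most aggressive setting while verifying its effect on peak memory (both variables are described in the Environment Variables section above):

```shell
# Level 2: aggressive recomputation, lowest peak memory, extra compute
export ORTMODULE_MEMORY_OPT_LEVEL=2

# Print memory statistics so the reduction can be confirmed in the logs
export ORTMODULE_PRINT_MEMORY_STATS=1
```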

Best Practices

Wrap Order Matters

Recommended: wrap with ORTModule before applying other wrappers such as DistributedDataParallel or DeepSpeed.
# Good
model = ORTModule(model)
model = DistributedDataParallel(model)

# Also works
model = ORTModule(model)
model = deepspeed.initialize(...)

Compatibility Notes

  • ✅ Compatible with torch.nn.parallel.DistributedDataParallel
  • ✅ Compatible with DeepSpeed
  • ✅ Compatible with PyTorch Lightning
  • ❌ NOT compatible with torch.nn.DataParallel (use DDP instead)
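
As a concrete sketch of the DDP pairing, the snippet below runs a single-process gloo process group on CPU. The ORTModule wrap is commented out because it requires onnxruntime-training to be installed; the address and port are arbitrary placeholders.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Minimal single-process process group (placeholder address/port)
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 2)
# model = ORTModule(model)  # wrap BEFORE DDP (needs onnxruntime-training)
model = DDP(model)

out = model(torch.randn(3, 4))
print(out.shape)  # torch.Size([3, 2])

dist.destroy_process_group()
```

Note that torch.nn.DataParallel cannot be substituted here; ORTModule only supports the DDP-style single-process-per-device pattern.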

Convergence Debugging

If you encounter convergence issues, collect activation statistics:
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.utils.hooks import (
    GlobalSubscriberManager,
    StatisticsSubscriber
)

model = ORTModule(model)
GlobalSubscriberManager.subscribe(
    model,
    [StatisticsSubscriber(
        output_dir="ort_out",
        override_output_dir=True
    )]
)

Performance Benefits

Typical speedups with ORTModule:
  • BERT-Large: 1.4x faster training
  • GPT-2: 1.5x faster training
  • Vision Transformers: 1.3-1.6x faster training
  • Memory reduction: 20-40% lower peak memory usage with optimization

Next Steps

  • Distributed Training: scale ORTModule across multiple GPUs
  • Training Overview: learn about other training options