ORTModule for PyTorch Integration
ORTModule is the easiest way to accelerate your PyTorch training. It’s a drop-in replacement for torch.nn.Module that leverages ONNX Runtime’s optimized training backend.
Quick Start
Add just two lines to your existing PyTorch training code: import ORTModule and wrap your model.

Complete Training Example
ORTModule drops into a complete MNIST training loop with the same two-line change; everything else in the loop stays standard PyTorch.

HuggingFace Transformers Example
ORTModule works seamlessly with HuggingFace transformers: wrap the model just like any other nn.Module.

Debug Options
For development and debugging, ORTModule provides detailed logging and can save the exported ONNX graphs for inspection.

Log Levels
- WARNING (default): User-facing warnings and errors
- INFO: Experimental feature stats, more error details
- DEVINFO: Recommended for debugging, includes all rank logs
- VERBOSE: Maximum verbosity, backend and exporter logs
Environment Variables
ORTModule behavior can be customized via environment variables, grouped below by the feature they control.

Fallback Policy
ONNX Opset Version
Save ONNX Models
Memory Optimization
Cache Exported Models
Computation Optimizations
Attention Optimizations
Triton Integration
Custom Autograd Functions
Debugging Options
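The subsections above each map to one or more variables. A consolidated sketch; the names follow recent ONNX Runtime releases and should be verified against your installed version:

```shell
# Variable names are assumptions from recent ONNX Runtime releases;
# check the release notes for your installed version.
export ORTMODULE_FALLBACK_POLICY=FALLBACK_DISABLE  # raise instead of silently falling back to PyTorch
export ORTMODULE_ONNX_OPSET_VERSION=17             # pin the exporter's ONNX opset
export ORTMODULE_CACHE_DIR=/tmp/ortmodule_cache    # reuse exported models across runs
export ORTMODULE_MEMORY_OPT_LEVEL=1                # enable subgraph recomputation
export ORTMODULE_USE_TRITON=1                      # allow Triton-generated kernels
```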
Performance Optimizations
FusedAdam Optimizer
Replace PyTorch’s AdamW with FusedAdam for faster parameter updates.

Combined with DeepSpeed
Combine ORTModule with DeepSpeed for maximum performance.

Memory Optimization
Reduce memory usage to train larger models by recomputing activations instead of storing them.

Memory Optimization Levels
- Level 0 (default): No recomputation
- Level 1: Recompute detected subgraphs (equivalent to PyTorch gradient checkpointing)
- Level 2: Aggressive recomputation including compromised subgraphs
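In recent releases the level is selected with a single environment variable; the name is an assumption to verify against your installed version:

```shell
# Name follows recent ONNX Runtime releases; verify against your version.
export ORTMODULE_MEMORY_OPT_LEVEL=1   # Level 1: recompute detected subgraphs
```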
Best Practices
Wrap Order Matters
Recommended: wrap with ORTModule before applying other wrappers.

Compatibility Notes
- ✅ Compatible with torch.nn.parallel.DistributedDataParallel
- ✅ Compatible with DeepSpeed
- ✅ Compatible with PyTorch Lightning
- ❌ NOT compatible with torch.nn.DataParallel (use DDP instead)
Convergence Debugging
If you encounter convergence issues, collect activation statistics to find where outputs diverge from plain PyTorch.

Performance Benefits
Typical speedups with ORTModule:
- BERT-Large: 1.4x faster training
- GPT-2: 1.5x faster training
- Vision Transformers: 1.3-1.6x faster training
- Memory reduction: 20-40% lower peak memory usage with optimization
Next Steps
Distributed Training
Scale ORTModule across multiple GPUs
Training Overview
Learn about other training options