Training Overview
ONNX Runtime provides multiple approaches to accelerate model training, from seamless PyTorch integration to low-level on-device training APIs.

Training Approaches
ONNX Runtime offers two main training solutions:

1. ORTModule for PyTorch
The easiest way to accelerate existing PyTorch training scripts. Simply wrap your torch.nn.Module with ORTModule to leverage ONNX Runtime's optimized training backend.
- Drop-in replacement for PyTorch modules
- Automatic graph optimization and kernel fusion
- Memory optimization with gradient checkpointing
- Compatible with distributed training frameworks (DeepSpeed, DDP)
- Supports mixed precision training
Use cases:
- Large language model training
- Computer vision model training
- Fine-tuning pre-trained models
- Distributed training workloads
2. On-Device Training API
Lightweight training API designed for edge devices and mobile platforms. Enables training directly on resource-constrained devices.

- Minimal dependencies and small binary size
- Optimized for mobile and edge devices
- Pre-compiled ONNX models for faster startup
- Cross-platform support (iOS, Android, embedded)
Use cases:
- Federated learning
- Personalized model adaptation
- Privacy-preserving on-device learning
- Edge AI applications
Performance Benefits
ONNX Runtime training delivers significant performance improvements:

Speed Improvements
- 1.3-2x faster training for large transformer models
- Optimized memory usage enables larger batch sizes
- Reduced memory footprint through gradient checkpointing
- Efficient mixed precision training (FP16/BF16)
Optimization Techniques
- Graph optimizations: Operator fusion, constant folding, redundant computation elimination
- Memory optimizations: Recomputation, memory-efficient gradient management
- Kernel optimizations: Fused kernels for common patterns (attention, layer norm, etc.)
- Data sparsity optimizations: Embedding sparse optimizer, label sparse optimizer
Optimizers
ONNX Runtime provides optimized implementations of common optimizers:

FusedAdam
Accelerated Adam optimizer using multi-tensor apply for batched gradient updates.

FP16_Optimizer
Complements DeepSpeed and Apex for improved mixed precision training.

Integration with Popular Frameworks
ORTModule integrates seamlessly with popular training frameworks:

DeepSpeed
PyTorch DDP (Distributed Data Parallel)
PyTorch Lightning
Installation
From PyPI (Recommended)
With CUDA Support
Configure ORTModule Extensions
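The steps above can be sketched as the following commands, following the torch-ort package conventions; the exact wheel for CUDA support depends on your CUDA and PyTorch versions, so check the install matrix on onnxruntime.ai before copying these verbatim.

```shell
# From PyPI (recommended): default build
pip install torch-ort
python -m torch_ort.configure   # builds/validates the ORTModule extensions

# With CUDA support: install the onnxruntime-training wheel matching your
# CUDA version first (see the onnxruntime.ai install matrix), then torch-ort.
pip install onnxruntime-training
pip install torch-ort
python -m torch_ort.configure
```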
Getting Started
- For PyTorch users: Start with ORTModule to accelerate existing training scripts
- For edge deployment: Use the On-Device Training API for mobile and embedded devices
- For distributed training: Check out Distributed Training setup guides
Next Steps
- ORTModule: Accelerate PyTorch training with a simple wrapper
- On-Device Training: Train models on edge devices and mobile platforms
- Distributed Training: Scale training across multiple GPUs and nodes