
Training Overview

ONNX Runtime provides multiple approaches to accelerate model training, from seamless PyTorch integration to low-level on-device training APIs.

Training Approaches

ONNX Runtime offers two main training solutions:

1. ORTModule for PyTorch

The easiest way to accelerate existing PyTorch training scripts. Simply wrap your torch.nn.Module with ORTModule to leverage ONNX Runtime’s optimized training backend.
from onnxruntime.training.ortmodule import ORTModule

model = build_model()
model = ORTModule(model)
Key Features:
  • Drop-in replacement for PyTorch modules
  • Automatic graph optimization and kernel fusion
  • Memory optimization with gradient checkpointing
  • Compatible with distributed training frameworks (DeepSpeed, DDP)
  • Supports mixed precision training
Use Cases:
  • Large language model training
  • Computer vision model training
  • Fine-tuning pre-trained models
  • Distributed training workloads
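Because ORTModule is a drop-in replacement, the surrounding training loop stays plain PyTorch. A minimal sketch of that idea (the tiny model, random data, and the import fallback are illustrative assumptions, not part of the ORTModule API):

```python
import torch

# Tiny illustrative model; any torch.nn.Module works the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1)
)

try:
    # The one-line change: wrap the module. Everything below is untouched.
    from onnxruntime.training.ortmodule import ORTModule
    model = ORTModule(model)
except ImportError:
    # Sketch still runs in plain PyTorch when onnxruntime-training is absent.
    pass

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()
x, y = torch.randn(32, 8), torch.randn(32, 1)

for _ in range(20):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)  # forward runs on the ORT backend when wrapped
    loss.backward()              # backward uses ORT's optimized training graph
    optimizer.step()
```

The rest of the script (data loading, logging, checkpointing) needs no changes, which is what "drop-in replacement" means in practice.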

2. On-Device Training API

A lightweight training API designed for edge devices and mobile platforms, enabling training directly on resource-constrained hardware.
from onnxruntime.training.api import Module, Optimizer, CheckpointState

state = CheckpointState.load_checkpoint("checkpoint.ckpt")
model = Module("training_model.onnx", state)
optimizer = Optimizer("optimizer.onnx", model)
Key Features:
  • Minimal dependencies and small binary size
  • Optimized for mobile and edge devices
  • Pre-compiled ONNX models for faster startup
  • Cross-platform support (iOS, Android, embedded)
Use Cases:
  • Federated learning
  • Personalized model adaptation
  • Privacy-preserving on-device learning
  • Edge AI applications
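Building on the objects above, a typical on-device loop calls the training model (a single call computes the loss and its gradients), steps the optimizer, then resets gradients. A sketch under the assumption that the artifact files were generated ahead of time; the function name and its arguments are hypothetical:

```python
def run_on_device_training(checkpoint_path, train_model_path,
                           optimizer_model_path, batches):
    """Sketch of an on-device training loop; artifact paths are assumptions."""
    from onnxruntime.training.api import CheckpointState, Module, Optimizer

    state = CheckpointState.load_checkpoint(checkpoint_path)
    model = Module(train_model_path, state)
    optimizer = Optimizer(optimizer_model_path, model)

    model.train()  # select the training graph
    for inputs, labels in batches:
        loss = model(inputs, labels)  # forward and backward in a single call
        optimizer.step()              # apply the computed gradients
        model.lazy_reset_grad()       # clear gradients before the next step

    # Persist the updated parameters back into the checkpoint file.
    CheckpointState.save_checkpoint(state, checkpoint_path)
```

Saving the checkpoint at the end is what makes personalization stick across app restarts on the device.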

Performance Benefits

ONNX Runtime training delivers significant performance improvements:

Speed Improvements

  • 1.3-2x faster training for large transformer models
  • Optimized memory usage enables larger batch sizes
  • Reduced memory footprint through gradient checkpointing
  • Efficient mixed precision training (FP16/BF16)
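The mixed precision point follows the standard torch autocast recipe; the loop is identical whether `model` is a plain module or wrapped in `ORTModule`. A sketch in which the tiny model and the CPU/BF16 fallback are illustrative assumptions:

```python
import torch

model = torch.nn.Linear(8, 1)  # could equally be ORTModule(model)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

use_cuda = torch.cuda.is_available()
device_type = "cuda" if use_cuda else "cpu"
# GradScaler is only needed for FP16 on CUDA; it is a pass-through when disabled.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

x, y = torch.randn(16, 8), torch.randn(16, 1)
for _ in range(5):
    optimizer.zero_grad()
    # FP16 on GPU, BF16 on CPU, matching the FP16/BF16 support noted above.
    with torch.autocast(device_type=device_type,
                        dtype=torch.float16 if use_cuda else torch.bfloat16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```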

Optimization Techniques

  • Graph optimizations: Operator fusion, constant folding, redundant computation elimination
  • Memory optimizations: Recomputation, memory-efficient gradient management
  • Kernel optimizations: Fused kernels for common patterns (attention, layer norm, etc.)
  • Data sparsity optimizations: Embedding sparse optimizer, label sparse optimizer
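To see what these graph-level optimizations produce, ORTModule's DebugOptions can save the ONNX graphs it builds and executes. A minimal sketch, assuming onnxruntime-training is installed; the helper name and file prefix are hypothetical:

```python
def wrap_with_graph_dump(model, prefix="ort_training_graph"):
    """Wrap a module so ORTModule dumps its generated ONNX graphs for inspection."""
    from onnxruntime.training.ortmodule import ORTModule, DebugOptions

    debug_options = DebugOptions(save_onnx=True, onnx_prefix=prefix)
    return ORTModule(model, debug_options)
```

Inspecting the saved graphs (e.g. in Netron) is a practical way to confirm which fusions and foldings were applied to your model.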

Optimizers

ONNX Runtime provides optimized implementations of common optimizers:

FusedAdam

An accelerated Adam optimizer that uses multi-tensor apply to batch parameter updates across tensors in fewer kernel launches:
from onnxruntime.training.optim import FusedAdam

optimizer = FusedAdam(model.parameters(), lr=1e-4)

FP16_Optimizer

A wrapper that complements DeepSpeed and Apex optimizers for improved mixed precision training performance:
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer

optimizer = FP16_Optimizer(optimizer)

Framework Integration

ORTModule integrates seamlessly with popular training frameworks:

DeepSpeed

import deepspeed
from onnxruntime.training.ortmodule import ORTModule

# model, optimizer, args, and lr_scheduler are created beforehand as usual
model = ORTModule(model)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    args=args,
    lr_scheduler=lr_scheduler
)

PyTorch DDP (Distributed Data Parallel)

from torch.nn.parallel import DistributedDataParallel as DDP
from onnxruntime.training.ortmodule import ORTModule

model = ORTModule(model)
model = DDP(model, device_ids=[local_rank])

PyTorch Lightning

import pytorch_lightning as pl
from onnxruntime.training.ortmodule import ORTModule

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = ORTModule(build_model())

    def forward(self, x):
        return self.model(x)

Installation

pip install onnxruntime-training

With CUDA Support

pip install onnxruntime-training-gpu

Configure ORTModule Extensions

python -m onnxruntime.training.ortmodule.torch_cpp_extensions.install

Getting Started

  1. For PyTorch users: Start with ORTModule to accelerate existing training scripts
  2. For edge deployment: Use the On-Device Training API for mobile and embedded devices
  3. For distributed training: Check out Distributed Training setup guides

Next Steps

  • ORTModule: Accelerate PyTorch training with a simple wrapper
  • On-Device Training: Train models on edge devices and mobile platforms
  • Distributed Training: Scale training across multiple GPUs and nodes