On-Device Training
The ONNX Runtime Training API enables training directly on edge devices, mobile platforms, and embedded systems. It provides a lightweight, cross-platform solution for federated learning, model personalization, and privacy-preserving on-device learning.
Overview
Unlike ORTModule, which wraps PyTorch models, the On-Device Training API works with pre-compiled ONNX models. This approach offers:
- Minimal dependencies: No PyTorch or heavy ML framework required
- Small binary size: Optimized for resource-constrained devices
- Cross-platform: Works on iOS, Android, Linux, Windows, and embedded systems
- Fast startup: Pre-compiled models eliminate export overhead
- Privacy-preserving: Train models locally without sending data to the cloud
Key Concepts
Training Artifacts
The Training API requires four artifacts:
- Training Model (training_model.onnx): Base model + loss + gradient graph
- Evaluation Model (eval_model.onnx): Base model + loss (optional)
- Optimizer Model (optimizer_model.onnx): Optimizer update graph (optional)
- Checkpoint (checkpoint): Model parameters and optimizer state
These artifacts are generated offline with the generate_artifacts utility.
Quick Start
Step 1: Generate Training Artifacts
First, export your PyTorch model to ONNX and generate the training artifacts with the generate_artifacts utility.
Step 2: Training Loop
Use the generated artifacts for training.
Complete Example
Here's a complete example with a simple classifier.
Advanced Features
Custom Loss Functions
Define custom loss functions using the onnxblock API.
Nominal Checkpoints
For on-device applications, use nominal checkpoints to reduce package size. They are useful when:
- Packaging models with mobile apps
- Parameters will be loaded from a separate source
- Reducing initial app download size
ORT Format
Convert models to ORT format for faster loading.
Working with OrtValues
For better performance, use OrtValues instead of numpy arrays.
Checkpoint Management
The CheckpointState class provides parameter access and management.
Supported Loss Functions
- LossType.MSELoss: Mean squared error loss
- LossType.CrossEntropyLoss: Cross-entropy loss for classification
- LossType.BCEWithLogitsLoss: Binary cross-entropy with logits
- LossType.L1Loss: L1 (absolute error) loss
Supported Optimizers
- OptimType.AdamW: Adam with weight decay
- OptimType.SGD: Stochastic gradient descent
Mobile and Edge Deployment
iOS Example
Android Example
Use Cases
Federated Learning
Train models across multiple devices without centralizing data:
- Deploy initial model to all devices
- Each device trains locally
- Aggregate parameter updates on server
- Distribute updated model
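The server-side aggregation step can be sketched with plain numpy (a FedAvg-style mean). In practice each device's flattened parameter vector might come from Module.get_contiguous_parameters — an assumption here; the sketch uses hardcoded vectors.

```python
import numpy as np

def federated_average(updates):
    """Average flattened parameter vectors from multiple devices (FedAvg)."""
    return np.mean(np.stack(updates), axis=0)

# Three devices report their locally trained parameter vectors.
device_params = [
    np.array([1.0, 2.0, 3.0]),
    np.array([3.0, 2.0, 1.0]),
    np.array([2.0, 2.0, 2.0]),
]
global_params = federated_average(device_params)
print(global_params)  # → [2. 2. 2.]
```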
Model Personalization
Adapt pre-trained models to individual users:
- Ship pre-trained model with app
- Fine-tune on user’s device with their data
- Keep personalized model local
Edge AI Applications
Continuous learning on edge devices:
- Deploy model to edge device (IoT, robotics)
- Collect local data
- Train incrementally
- Adapt to changing conditions
Performance Tips
- Use OrtValues: Avoid numpy conversion overhead
- Batch Processing: Process multiple samples together
- ORT Format: Use .ort format for faster loading
- Quantization: Consider quantized models for mobile
- Memory Management: Reuse buffers when possible
Next Steps
- ORTModule: For cloud-based PyTorch training
- Training Overview: Explore all training options