On-Device Training

The ONNX Runtime Training API enables training directly on edge devices, mobile platforms, and embedded systems. It provides a lightweight, cross-platform solution for federated learning, model personalization, and privacy-preserving on-device learning.

Overview

Unlike ORTModule, which wraps PyTorch models, the On-Device Training API works with pre-compiled ONNX models. This approach offers:
  • Minimal dependencies: No PyTorch or heavy ML framework required
  • Small binary size: Optimized for resource-constrained devices
  • Cross-platform: Works on iOS, Android, Linux, Windows, and embedded systems
  • Fast startup: Pre-compiled models eliminate export overhead
  • Privacy-preserving: Train models locally without sending data to the cloud

Key Concepts

Training Artifacts

The Training API requires four artifacts:
  1. Training Model (training_model.onnx): Base model + loss + gradient graph
  2. Evaluation Model (eval_model.onnx): Base model + loss (optional)
  3. Optimizer Model (optimizer.onnx): Optimizer update graph (optional)
  4. Checkpoint (checkpoint.ckpt): Model parameters and optimizer state
These artifacts are generated offline using the generate_artifacts utility.

Quick Start

Step 1: Generate Training Artifacts

First, export your PyTorch model to ONNX and generate training artifacts:
import torch
import onnx
from onnxruntime.training import artifacts

# Export your PyTorch model
model = build_your_model()
sample_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    sample_input,
    "base_model.onnx",
    export_params=True,
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False
)

# Load the base model
base_model = onnx.load("base_model.onnx")

# Define which parameters require gradients
requires_grad = ["conv1.weight", "conv1.bias", "fc.weight", "fc.bias"]
frozen_params = ["conv2.weight", "conv2.bias"]  # Optional: freeze some layers

# Generate training artifacts
artifacts.generate_artifacts(
    base_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW
)

# This generates:
# - training_model.onnx
# - eval_model.onnx  
# - optimizer.onnx
# - checkpoint (the initial training state)

Step 2: Training Loop

Use the generated artifacts for training:
import numpy as np
from onnxruntime.training.api import Module, Optimizer, CheckpointState

# Load checkpoint
state = CheckpointState.load_checkpoint("checkpoint.ckpt")

# Create training module
model = Module(
    "training_model.onnx",
    state,
    eval_model_uri="eval_model.onnx",
    device="cuda"  # or "cpu"
)

# Create optimizer
optimizer = Optimizer("optimizer.onnx", model)

# Training loop
model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Prepare inputs (as numpy arrays)
        input_data = batch['data'].numpy()
        labels = batch['labels'].numpy()
        
        # Forward + backward pass (the training model computes the loss
        # and the gradients in a single call)
        outputs = model(input_data, labels)
        loss = outputs[0]
        
        # Update parameters, then reset gradients for the next step
        optimizer.step()
        model.lazy_reset_grad()
        
        print(f"Epoch {epoch}, Loss: {loss}")

# Evaluation
model.eval()
for batch in val_dataloader:
    input_data = batch['data'].numpy()
    labels = batch['labels'].numpy()
    
    outputs = model(input_data, labels)
    # Process evaluation outputs

# Save checkpoint
CheckpointState.save_checkpoint(state, "checkpoint_final.ckpt")

Complete Example

Here’s a complete example with a simple classifier:
import torch
import torch.nn as nn
import onnx
import numpy as np
from onnxruntime.training import artifacts
from onnxruntime.training.api import Module, Optimizer, CheckpointState

# Step 1: Define and export PyTorch model
class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Create and export model
pt_model = SimpleClassifier()
sample_input = torch.randn(1, 784)

torch.onnx.export(
    pt_model,
    sample_input,
    "classifier.onnx",
    export_params=True,
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,
    input_names=['input'],
    output_names=['output']
)

# Step 2: Generate training artifacts
base_model = onnx.load("classifier.onnx")

artifacts.generate_artifacts(
    base_model,
    requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="./training_artifacts"
)

# Step 3: Train the model
state = CheckpointState.load_checkpoint(
    "./training_artifacts/checkpoint"
)

model = Module(
    "./training_artifacts/training_model.onnx",
    state,
    eval_model_uri="./training_artifacts/eval_model.onnx",
    device="cpu"
)

optimizer = Optimizer(
    "./training_artifacts/optimizer.onnx",
    model
)

# Generate dummy training data
def generate_batch(batch_size=32):
    X = np.random.randn(batch_size, 784).astype(np.float32)
    y = np.random.randint(0, 10, size=batch_size).astype(np.int64)
    return X, y

# Training loop
model.train()
for epoch in range(5):
    epoch_loss = 0.0
    num_batches = 100
    
    for batch_idx in range(num_batches):
        X, y = generate_batch()
        
        # Forward + backward pass (loss and gradients computed in one call)
        outputs = model(X, y)
        loss = outputs[0]  # First output is the loss
        
        # Update parameters, then reset gradients for the next step
        optimizer.step()
        model.lazy_reset_grad()
        
        epoch_loss += float(loss)
    
    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")

# Save final checkpoint
CheckpointState.save_checkpoint(
    state,
    "./training_artifacts/checkpoint_final.ckpt"
)

print("Training completed!")

Advanced Features

Custom Loss Functions

Define custom loss functions using ONNXBlock:
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training import artifacts

class WeightedAverageLoss(onnxblock.Block):
    def __init__(self):
        super().__init__()
        self._loss1 = onnxblock.loss.MSELoss()
        self._loss2 = onnxblock.loss.MSELoss()
        self._w1 = onnxblock.blocks.Constant(0.4)
        self._w2 = onnxblock.blocks.Constant(0.6)
        self._add = onnxblock.blocks.Add()
        self._mul = onnxblock.blocks.Mul()
    
    def build(self, loss_input_1, loss_input_2):
        return self._add(
            self._mul(self._w1(), self._loss1(loss_input_1, target_name="target1")),
            self._mul(self._w2(), self._loss2(loss_input_2, target_name="target2"))
        )

# Use custom loss
custom_loss = WeightedAverageLoss()
artifacts.generate_artifacts(
    base_model,
    requires_grad=requires_grad,
    loss=custom_loss,
    optimizer=artifacts.OptimType.AdamW
)

Nominal Checkpoints

For on-device applications, use nominal checkpoints to reduce package size:
artifacts.generate_artifacts(
    base_model,
    requires_grad=requires_grad,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    nominal_checkpoint=True  # Generate lightweight checkpoint
)
Nominal checkpoints contain only parameter metadata, not actual values. They’re useful when:
  • Packaging models with mobile apps
  • Parameters will be loaded from a separate source (as sketched below)
  • Reducing initial app download size
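
Because a nominal checkpoint carries no parameter values, the values must be supplied before the first training step. The following is a minimal sketch of one way to do this, assuming the full parameters ship separately as a flat float32 buffer (trainable_params.bin and the checkpoint path are illustrative) and using Module.copy_buffer_to_parameters:

import numpy as np
from onnxruntime import OrtValue
from onnxruntime.training.api import Module, CheckpointState

# Load the nominal checkpoint; it describes the parameters but holds no values
state = CheckpointState.load_checkpoint("./training_artifacts/nominal_checkpoint")
model = Module("./training_artifacts/training_model.onnx", state, device="cpu")

# Obtain the real parameter values from a separate source and copy them in
# as one contiguous float32 buffer covering all trainable parameters
param_buffer = np.fromfile("trainable_params.bin", dtype=np.float32)
model.copy_buffer_to_parameters(
    OrtValue.ortvalue_from_numpy(param_buffer),
    trainable_only=True
)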

ORT Format

Convert models to ORT format for faster loading:
artifacts.generate_artifacts(
    base_model,
    requires_grad=requires_grad,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    ort_format=True  # Generate .ort files instead of .onnx
)
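
The resulting .ort files are loaded the same way as the .onnx artifacts; only the file names change (paths illustrative):

model = Module(
    "./training_artifacts/training_model.ort",
    state,
    eval_model_uri="./training_artifacts/eval_model.ort",
    device="cpu"
)
optimizer = Optimizer("./training_artifacts/optimizer.ort", model)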

Working with OrtValues

For better performance, use OrtValues instead of numpy arrays:
from onnxruntime import OrtValue

# Create OrtValues from numpy arrays
input_data = np.random.randn(32, 784).astype(np.float32)
labels = np.random.randint(0, 10, size=32).astype(np.int64)
ort_input = OrtValue.ortvalue_from_numpy(input_data)
ort_labels = OrtValue.ortvalue_from_numpy(labels)

# Pass to model
outputs = model(ort_input, ort_labels)

Checkpoint Management

The CheckpointState provides parameter access and management:
from onnxruntime.training.api import CheckpointState

# Load checkpoint
state = CheckpointState.load_checkpoint("checkpoint.ckpt")

# Access parameters
for param_name in state.parameters:
    param = state.parameters[param_name]
    print(f"{param.name}: shape={param.data.shape}, requires_grad={param.requires_grad}")
    
    # Modify parameter
    if param.name == "fc1.weight":
        param.data = new_weights  # Update weights
        
    # Access gradients
    if param.grad is not None:
        print(f"Gradient: {param.grad}")

# Save modified checkpoint
CheckpointState.save_checkpoint(state, "checkpoint_modified.ckpt")

Supported Loss Functions

  • LossType.MSELoss: Mean squared error loss
  • LossType.CrossEntropyLoss: Cross-entropy loss for classification
  • LossType.BCEWithLogitsLoss: Binary cross-entropy with logits (see the example below)
  • LossType.L1Loss: L1 (absolute error) loss
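
The loss is chosen when the training artifacts are generated. As a small illustration (model and parameter names are placeholders), a binary classifier that outputs a single logit per example could be paired with BCEWithLogitsLoss:

artifacts.generate_artifacts(
    base_model,
    requires_grad=["fc.weight", "fc.bias"],
    loss=artifacts.LossType.BCEWithLogitsLoss,
    optimizer=artifacts.OptimType.SGD,
    artifact_directory="./binary_artifacts"
)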

Supported Optimizers

  • OptimType.AdamW: Adam with weight decay
  • OptimType.SGD: Stochastic gradient descent

Mobile and Edge Deployment

iOS Example

import onnxruntime_training

// Load checkpoint
let state = try CheckpointState.loadCheckpoint("checkpoint.ckpt")

// Create module
let model = try Module(
    trainModelUri: "training_model.onnx",
    state: state,
    device: "cpu"
)

// Training loop (optimizer creation and input preparation omitted)
model.setTrainingMode(true)
for epoch in 0..<numEpochs {
    let outputs = try model.call(inputs: inputs)
    try optimizer.step()
}

Android Example

import ai.onnxruntime.training.*

// Load checkpoint
val state = CheckpointState.loadCheckpoint("checkpoint.ckpt")

// Create module
val model = Module(
    trainModelUri = "training_model.onnx",
    state = state,
    device = "cpu"
)

// Training loop (optimizer creation and input preparation omitted)
model.train()
for (epoch in 0 until numEpochs) {
    val outputs = model(inputs)
    optimizer.step()
}

Use Cases

Federated Learning

Train models across multiple devices without centralizing data (a parameter-exchange sketch follows the steps):
  1. Deploy initial model to all devices
  2. Each device trains locally
  3. Aggregate parameter updates on server
  4. Distribute updated model
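
On the device side, the Python API can read and write all trainable parameters as one contiguous buffer, which is the natural unit to exchange with the server. A minimal sketch, assuming a Module created as in the earlier examples (transport and server-side averaging are placeholders):

from onnxruntime import OrtValue

# After local training, read the trainable parameters as a flat numpy array
local_params = model.get_contiguous_parameters(trainable_only=True).numpy()

# ... send local_params to the server and receive the aggregated result ...
aggregated_params = local_params  # placeholder for the server response

# Load the aggregated parameters back into the local model
model.copy_buffer_to_parameters(
    OrtValue.ortvalue_from_numpy(aggregated_params),
    trainable_only=True
)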

Model Personalization

Adapt pre-trained models to individual users (a fine-tuning sketch follows the steps):
  1. Ship pre-trained model with app
  2. Fine-tune on user’s device with their data
  3. Keep personalized model local
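
A common recipe is to freeze the backbone, fine-tune only the final layer, and then export an inference-only model that never leaves the device. The sketch below reuses the SimpleClassifier names from the complete example and assumes a trained Module; Module.export_model_for_inferencing writes an inference graph containing the personalized parameter values:

# Generate artifacts that train only the classifier head
artifacts.generate_artifacts(
    base_model,
    requires_grad=["fc2.weight", "fc2.bias"],   # personalize the head only
    frozen_params=["fc1.weight", "fc1.bias"],   # keep the backbone fixed
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="./personalization_artifacts"
)

# ... run the on-device training loop shown earlier ...

# Export an inference-only model with the personalized weights
model.export_model_for_inferencing("personalized_model.onnx", ["output"])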

Edge AI Applications

Continuous learning on edge devices (a minimal training-loop sketch follows the steps):
  1. Deploy model to edge device (IoT, robotics)
  2. Collect local data
  3. Train incrementally
  4. Adapt to changing conditions
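
A minimal sketch of the incremental pattern, assuming a Module, Optimizer, and CheckpointState created as in the earlier examples; on_new_sample is a hypothetical callback fed by the device's data source:

buffer_x, buffer_y = [], []

def on_new_sample(x, y):
    buffer_x.append(x)
    buffer_y.append(y)
    if len(buffer_x) >= 32:  # train whenever a small batch has accumulated
        model.train()
        outputs = model(np.stack(buffer_x).astype(np.float32),
                        np.array(buffer_y, dtype=np.int64))
        optimizer.step()
        model.lazy_reset_grad()
        buffer_x.clear()
        buffer_y.clear()
        # Persist progress so training survives restarts
        CheckpointState.save_checkpoint(state, "checkpoint_latest.ckpt")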

Performance Tips

  1. Use OrtValues: Avoid numpy conversion overhead
  2. Batch Processing: Process multiple samples together
  3. ORT Format: Use .ort format for faster loading
  4. Quantization: Consider quantized models for mobile
  5. Memory Management: Reuse buffers when possible
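
As an illustration of tips 1 and 5, the following sketch pre-allocates the input OrtValues once and refreshes their contents with OrtValue.update_inplace on each step. It assumes a fixed batch shape, the model and optimizer created above, and batches as a placeholder iterable of numpy pairs:

from onnxruntime import OrtValue

# Allocate the input buffers once
input_buf = OrtValue.ortvalue_from_numpy(np.zeros((32, 784), dtype=np.float32))
label_buf = OrtValue.ortvalue_from_numpy(np.zeros((32,), dtype=np.int64))

for X, y in batches:
    # Overwrite the existing buffers (same shape and dtype required)
    input_buf.update_inplace(X)
    label_buf.update_inplace(y)
    outputs = model(input_buf, label_buf)
    optimizer.step()
    model.lazy_reset_grad()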

Next Steps

  • ORTModule: For cloud-based PyTorch training
  • Training Overview: Explore all training options