
Converting PyTorch Models to ONNX

PyTorch provides native support for exporting models to ONNX format through the torch.onnx.export() function. This guide covers the conversion process with practical examples.

Prerequisites

pip install torch onnx onnxruntime

Basic Conversion

Simple Model Export

Here’s a basic example of exporting a PyTorch model to ONNX:
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    """Minimal example network: a single 10 -> 5 linear projection."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)

    def forward(self, x):
        """Project an input of shape (..., 10) down to (..., 5)."""
        return self.linear(x)

# Instantiate the network and switch it to inference mode before tracing
net = SimpleModel()
net.eval()

# Example input: export traces the graph by running this tensor through
example_input = torch.randn(1, 10)

# Serialize the traced graph to an ONNX file
torch.onnx.export(
    net,
    example_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=14,
    do_constant_folding=True,
)

Advanced Export with Dynamic Axes

For models that need to handle variable input sizes (e.g., different batch sizes or sequence lengths), use dynamic axes:
import torch
from transformers import AutoModel, AutoTokenizer

# Load pre-trained model
model_name = "bert-base-uncased"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

model.eval()

# Prepare example inputs
text = "This is a sample input"
example_inputs = tokenizer(text, return_tensors="pt")

# Mark batch and sequence dimensions as dynamic for every tokenizer input
# (input_ids, attention_mask, and token_type_ids for BERT) and for the
# outputs.  NOTE: dynamic_axes keys must exactly match names given in
# input_names/output_names -- a key such as "output" that matches neither
# is silently ignored by the exporter, leaving that tensor's shape static.
dynamic_axes = {name: {0: "batch_size", 1: "seq_len"} for name in example_inputs}
dynamic_axes["last_hidden_state"] = {0: "batch_size", 1: "seq_len"}
dynamic_axes["pooler_output"] = {0: "batch_size"}  # pooled vector has no seq dim

# Export with dynamic axes
torch.onnx.export(
    model,
    args=tuple(example_inputs.values()),
    f="bert_model.onnx",
    input_names=list(example_inputs.keys()),
    output_names=["last_hidden_state", "pooler_output"],
    dynamic_axes=dynamic_axes,
    opset_version=14,
    do_constant_folding=True,
)

ONNX Runtime Export Helper

For finer control, you can wrap torch.onnx.export in a small helper that exposes its compatibility options as explicit parameters with sensible defaults:
from torch._C._onnx import OperatorExportTypes
import torch

def torch_onnx_export(
    model,
    args,
    f,
    export_params=True,
    verbose=False,
    training=torch.onnx.TrainingMode.EVAL,
    input_names=None,
    output_names=None,
    operator_export_type=OperatorExportTypes.ONNX,
    opset_version=14,
    do_constant_folding=True,
    dynamic_axes=None,
    keep_initializers_as_inputs=None,
    custom_opsets=None,
    export_modules_as_functions=False,
):
    torch.onnx.export(
        model=model,
        args=args,
        f=f,
        export_params=export_params,
        verbose=verbose,
        training=training,
        input_names=input_names,
        output_names=output_names,
        operator_export_type=operator_export_type,
        opset_version=opset_version,
        do_constant_folding=do_constant_folding,
        dynamic_axes=dynamic_axes,
        keep_initializers_as_inputs=keep_initializers_as_inputs,
        custom_opsets=custom_opsets,
        export_modules_as_functions=export_modules_as_functions,
        dynamo=False,
    )

# Usage
model = YourModel()  # placeholder -- substitute your own nn.Module
model.eval()

# Example input: a single RGB image at 224x224
sample_image = torch.randn(1, 3, 224, 224)

torch_onnx_export(
    model=model,
    args=(sample_image,),
    f="model.onnx",
    input_names=["image"],
    output_names=["output"],
    opset_version=14,
)

Exporting Vision Transformers

Example for exporting Vision Transformer (ViT) models:
import torch
import numpy as np
from transformers import AutoFeatureExtractor, AutoModel

model_name = "google/vit-base-patch16-224"
model = AutoModel.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

model.eval()

# Build a random HxWx3 uint8 image to use as the tracing example
image_size = 224
random_image = np.random.randint(
    0, 256, size=(image_size, image_size, 3), dtype=np.uint8
)

example_inputs = feature_extractor(random_image, return_tensors="pt")

# Only the batch dimension varies at run time; spatial dims stay fixed
dynamic_axes = {"pixel_values": {0: "batch_size"}}

torch.onnx.export(
    model,
    args=tuple(example_inputs.values()),
    f="vit_model.onnx",
    input_names=["pixel_values"],
    output_names=["last_hidden_state"],
    dynamic_axes=dynamic_axes,
    opset_version=14,
    do_constant_folding=True,
)

Handling Large Models

For models larger than 2GB, use external data format:
import torch

model = LargeModel()
model.eval()

dummy_input = torch.randn(1, 512)

# NOTE: the old `use_external_data_format=True` argument was deprecated in
# PyTorch 1.10 and later removed -- passing it to a modern torch.onnx.export
# raises a TypeError.  The exporter now detects when the serialized model
# would exceed protobuf's 2GB limit and automatically stores the weights in
# external data files next to the .onnx graph, so no flag is needed.
torch.onnx.export(
    model,
    dummy_input,
    "large_model.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=14,
)

Validating the Exported Model

Always validate your ONNX model after export:
import onnx
import onnxruntime as ort
import numpy as np

# Structural validation: load the graph and run the ONNX checker over it
checked_model = onnx.load("model.onnx")
onnx.checker.check_model(checked_model)
print("ONNX model is valid")

# Smoke-test inference through ONNX Runtime
session = ort.InferenceSession("model.onnx")
first_input = session.get_inputs()[0]
first_output = session.get_outputs()[0]

# Random input matching the example model's (1, 10) float32 signature
feed = {first_input.name: np.random.randn(1, 10).astype(np.float32)}

result = session.run([first_output.name], feed)
print(f"Output shape: {result[0].shape}")

Common Issues and Solutions

Issue: Unsupported Operations

Some PyTorch operations may not have ONNX equivalents. Replace them with ONNX-compatible alternatives:
# Problematic: torch.triu may not export properly
def triu_onnx(x, diagonal=0, out=None):
    """Export-friendly replacement for ``torch.triu`` on square 2-D tensors.

    Builds the upper-triangular mask from index comparisons instead of
    calling the triu op, so the graph exports as plain ONNX ops.  This also
    fixes two defects in the earlier template-based version: it called
    ``torch.triu`` internally (infinite recursion once ``torch.triu`` is
    monkeypatched to this function) and its fixed 1024x1024 template broke
    with a broadcasting error for larger inputs.

    Args:
        x: square 2-D tensor to mask.
        diagonal: same meaning as in ``torch.triu``.
        out: unsupported; must be None (kept for signature compatibility).

    Returns:
        Tensor equal to ``torch.triu(x, diagonal)``.
    """
    assert out is None
    assert len(x.shape) == 2 and x.size(0) == x.size(1)

    # Element (i, j) is on/above the requested diagonal iff j - i >= diagonal.
    idx = torch.arange(x.size(0), device=x.device)
    mask = (idx.unsqueeze(0) - idx.unsqueeze(1)) >= diagonal
    return torch.where(mask, x, torch.zeros_like(x))

# Replace torch.triu only for the duration of the export.  The try/finally
# guarantees the real implementation is restored even when export raises --
# without it, a failed export would leave torch.triu monkeypatched for the
# rest of the process.
torch_triu = torch.triu
torch.triu = triu_onnx
try:
    # Export model
    torch.onnx.export(model, dummy_input, "model.onnx")
finally:
    # Restore original function
    torch.triu = torch_triu

Issue: Dynamic Control Flow

Avoid dynamic control flow (if/else based on input values). Use static shapes or ONNX operators instead.

Best Practices

  1. Always set model to eval mode: model.eval() before export
  2. Use appropriate opset version: Version 14+ is recommended for most models
  3. Enable constant folding: Set do_constant_folding=True for optimization
  4. Provide meaningful names: Use descriptive input_names and output_names
  5. Test with real inputs: Validate exported model with actual data
  6. Check for warnings: Review export warnings and address compatibility issues

Next Steps