
Model Optimization for Inference

Learn how to optimize ONNX models for production inference with graph optimization, quantization, profiling, and performance tuning techniques.

Overview

Model optimization is crucial for production deployment. ONNX Runtime provides multiple optimization strategies:
  • Graph Optimization: Fuse operators, eliminate redundant nodes, optimize memory layout
  • Quantization: Reduce model size and improve speed with reduced precision
  • Profiling: Identify performance bottlenecks
  • Memory Optimization: Reduce memory footprint and allocations
  • Threading: Optimize parallelism for multi-core processors

Graph Optimization

Graph optimization transforms the model computation graph for better performance.

Optimization Levels

ONNX Runtime provides four optimization levels:

1. Disabled (ORT_DISABLE_ALL)
  • No optimizations applied
  • Use for debugging or when optimizations cause issues

2. Basic (ORT_ENABLE_BASIC)
  • Constant folding
  • Redundant node elimination
  • Semantics-preserving node fusions

Faster session creation, moderate performance gains.

3. Extended (ORT_ENABLE_EXTENDED)
  • All basic optimizations
  • Complex node fusions (e.g., Conv + BatchNorm + Relu)
  • Node reordering
  • Algebraic simplifications

Balanced optimization; suitable for most use cases.

4. All (ORT_ENABLE_ALL)
  • All extended optimizations
  • Layout transformations (e.g., NCHWc format)
  • Advanced memory planning

Maximum performance at the cost of longer session creation time.

Applying Graph Optimization

import onnxruntime as ort

sess_options = ort.SessionOptions()

# Set optimization level
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# Save optimized model to file
sess_options.optimized_model_filepath = "optimized_model.onnx"

session = ort.InferenceSession("model.onnx", sess_options)

Common Graph Optimizations

Operator Fusion:
  • Conv + BatchNorm + Relu → FusedConv
  • MatMul + Add → Gemm
  • Multiple Transpose operations → Single Transpose
Constant Folding:
  • Pre-compute constant operations at graph load time
  • Reduces inference computation
Dead Code Elimination:
  • Remove unused nodes and outputs
  • Reduces memory and computation
Layout Optimization:
  • Convert NCHW → NCHWc (a blocked channel layout)
  • Better cache locality and vectorization
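To see why fusions such as MatMul + Add → Gemm are semantics-preserving, here is a minimal NumPy sketch (the function names are illustrative, not ONNX Runtime APIs) showing that the fused form computes the same result as the two separate operators:

```python
import numpy as np

def matmul_then_add(x, w, b):
    """Two separate ops: MatMul followed by Add (the pre-fusion graph)."""
    return np.matmul(x, w) + b

def gemm(x, w, b, alpha=1.0, beta=1.0):
    """One fused op, mirroring ONNX Gemm: alpha * (x @ w) + beta * b."""
    return alpha * np.matmul(x, w) + beta * b

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
w = rng.standard_normal((8, 16)).astype(np.float32)
b = rng.standard_normal((16,)).astype(np.float32)

# The fused operator produces identical results while dispatching a
# single kernel instead of two, which is the point of the optimization.
assert np.allclose(matmul_then_add(x, w, b), gemm(x, w, b))
```

The same argument applies to the other fusions listed above: the fused kernel is algebraically equivalent but avoids intermediate tensors and extra dispatch overhead.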

Quantization

Quantization reduces model size and improves inference speed by using lower precision (INT8) instead of FP32.

Dynamic Quantization

Weights are quantized offline, activations are quantized dynamically during inference.
import onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType

# Quantize model
model_input = "model.onnx"
model_output = "model_quantized.onnx"

quantize_dynamic(
    model_input,
    model_output,
    weight_type=QuantType.QInt8
)

print("Model quantized successfully")

# Use quantized model
session = ort.InferenceSession(model_output)
Benefits:
  • ~4x model size reduction
  • 2-4x inference speedup on CPU
  • Typically minimal accuracy loss (often < 1%)
  • No calibration data required
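Conceptually, INT8 quantization maps FP32 values to 8-bit integers through a scale factor. A minimal NumPy sketch of symmetric per-tensor quantization (illustrative only; the quantization tool chooses scales per operator and may use asymmetric schemes):

```python
import numpy as np

def quantize_symmetric(x):
    """Map FP32 values to INT8 using a single per-tensor scale."""
    scale = np.abs(x).max() / 127.0  # largest magnitude maps to 127
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
weights = rng.standard_normal(1000).astype(np.float32)

q, scale = quantize_symmetric(weights)
restored = dequantize(q, scale)

# INT8 storage is 4x smaller than FP32, and the round-trip error per
# element is bounded by half the quantization step (scale / 2).
assert q.nbytes * 4 == weights.nbytes
assert np.abs(weights - restored).max() <= scale / 2 + 1e-6
```

This makes the size and accuracy trade-off concrete: storage shrinks 4x, while each weight moves by at most half a quantization step.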

Static Quantization (QDQ)

Both weights and activations are quantized using calibration data.
from onnxruntime.quantization import quantize_static, QuantFormat, CalibrationDataReader
import numpy as np

class DataReader(CalibrationDataReader):
    def __init__(self, calibration_data):
        self.data = calibration_data
        self.iterator = iter(self.data)
    
    def get_next(self):
        try:
            return next(self.iterator)
        except StopIteration:
            return None
    
    def rewind(self):
        self.iterator = iter(self.data)

# Generate calibration data
calibration_data = []
for i in range(100):
    input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
    calibration_data.append({"input": input_data})

data_reader = DataReader(calibration_data)

# Quantize model
quantize_static(
    model_input="model.onnx",
    model_output="model_static_quant.onnx",
    calibration_data_reader=data_reader,
    quant_format=QuantFormat.QDQ
)
Benefits:
  • Better accuracy than dynamic quantization
  • Faster inference than dynamic quantization
Tradeoff: requires a representative calibration dataset.

Quantization Guidelines

  • Use dynamic quantization for quick deployment with minimal setup
  • Use static quantization for maximum performance when calibration data is available
Always evaluate quantized model accuracy on your validation set:
import numpy as np
import onnxruntime as ort

# Compare FP32 vs INT8 outputs on the same validation inputs
fp32_session = ort.InferenceSession("model.onnx")
int8_session = ort.InferenceSession("model_quantized.onnx")

# inputs: a dict of representative validation model inputs
fp32_output = fp32_session.run(None, inputs)[0]
int8_output = int8_session.run(None, inputs)[0]

max_diff = np.abs(fp32_output - int8_output).max()
print(f"Max output difference: {max_diff:.6f}")
Some operators may not be quantized; the quantization tool skips unsupported operators automatically. Hardware support for INT8 also varies:
  • x86 CPUs: best INT8 acceleration with VNNI or AVX-512 instructions
  • ARM CPUs: NEON instructions accelerate INT8
  • GPUs: limited INT8 support; check the execution provider documentation
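On Linux, one way to check whether your CPU exposes these instruction sets is to read /proc/cpuinfo (a platform-specific sketch; the helper name is illustrative and this will simply return False on other operating systems):

```python
def has_cpu_flag(flag):
    """Return True if /proc/cpuinfo lists the given flag (Linux only)."""
    try:
        with open("/proc/cpuinfo") as f:
            for line in f:
                if line.startswith("flags"):
                    return flag in line.split()
    except OSError:
        pass  # not Linux, or /proc is unavailable
    return False

# VNNI typically appears as "avx512_vnni"; AVX-512 foundation as "avx512f"
for flag in ("avx2", "avx512f", "avx512_vnni"):
    print(f"{flag}: {has_cpu_flag(flag)}")
```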

Profiling

Profile model execution to identify bottlenecks.

Enable Profiling

import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.enable_profiling = True
sess_options.profile_file_prefix = "ort_profile"

session = ort.InferenceSession("model.onnx", sess_options)

# Run inference
outputs = session.run(None, inputs)

# Profiling file saved: ort_profile_<timestamp>.json

Analyze Profiling Results

The profiling output is a JSON file with Chrome Tracing format. View in Chrome:
  1. Open Chrome browser
  2. Navigate to chrome://tracing
  3. Click “Load” and select the profiling JSON file
Analyze with Python:
import json
import pandas as pd

# Load profiling data
with open('ort_profile_timestamp.json', 'r') as f:
    profile = json.load(f)

# Extract operator timings
events = profile['traceEvents']
op_times = {}

for event in events:
    if event.get('cat') == 'Node' and 'dur' in event:
        name = event['name']
        duration = event['dur'] / 1000  # Convert to milliseconds
        
        if name not in op_times:
            op_times[name] = []
        op_times[name].append(duration)

# Calculate statistics
df = pd.DataFrame([
    {
        'operator': name,
        'count': len(times),
        'total_ms': sum(times),
        'avg_ms': sum(times) / len(times),
        'max_ms': max(times)
    }
    for name, times in op_times.items()
])

# Sort by total time
df = df.sort_values('total_ms', ascending=False)
print("Top 10 operators by total time:")
print(df.head(10))

Profiling Metrics

  • Kernel Time: Time spent executing each operator
  • Memory Allocation: Memory allocation events
  • Data Transfer: CPU-GPU data transfer time (if using GPU)
  • Session Overhead: Session initialization and cleanup

Memory Optimization

Memory Arena

Enable memory arena for efficient memory allocation:
sess_options = ort.SessionOptions()
sess_options.enable_cpu_mem_arena = True  # Enable CPU memory arena
Benefits:
  • Reduces memory fragmentation
  • Faster allocation/deallocation
  • Lower peak memory usage

Memory Pattern Optimization

sess_options.enable_mem_pattern = True
ONNX Runtime analyzes memory usage patterns and pre-allocates memory for better performance.

Sequential Execution Mode

For memory-constrained environments:
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
Reduces peak memory usage by executing operators sequentially instead of in parallel.

Threading Optimization

Intra-Op Threading

Parallelism within a single operator (e.g., matrix multiplication):
sess_options.intra_op_num_threads = 4
Guidelines:
  • Set to number of physical cores for CPU-bound operations
  • More threads ≠ always faster (overhead increases)
  • Start with physical core count and tune based on profiling

Inter-Op Threading

Parallelism between independent operators:
sess_options.inter_op_num_threads = 2
Guidelines:
  • Useful for models with parallel branches
  • Usually set to 1 or 2
  • Higher values can cause overhead

Threading Best Practices

import onnxruntime as ort
import multiprocessing

# Assumes 2-way hyperthreading; for an exact physical core count on
# heterogeneous systems, use psutil.cpu_count(logical=False)
physical_cores = multiprocessing.cpu_count() // 2

sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = physical_cores
sess_options.inter_op_num_threads = 1
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

Execution Provider Optimization

CPU Optimization

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = 4

# Use optimized CPU kernels
session = ort.InferenceSession(
    "model.onnx",
    sess_options,
    providers=['CPUExecutionProvider']
)

GPU Optimization (CUDA)

cuda_options = {
    'device_id': 0,
    'arena_extend_strategy': 'kSameAsRequested',
    'gpu_mem_limit': 2 * 1024 * 1024 * 1024,  # 2GB
    'cudnn_conv_algo_search': 'EXHAUSTIVE',  # Find best conv algorithm
}

session = ort.InferenceSession(
    "model.onnx",
    sess_options,
    providers=[('CUDAExecutionProvider', cuda_options)]
)

TensorRT Optimization

trt_options = {
    'device_id': 0,
    'trt_max_workspace_size': 2147483648,  # 2GB
    'trt_fp16_enable': True,  # Enable FP16
    'trt_int8_enable': True,  # Enable INT8 (requires calibration)
    'trt_engine_cache_enable': True,  # Cache TensorRT engines
    'trt_engine_cache_path': './trt_cache'
}

session = ort.InferenceSession(
    "model.onnx",
    providers=[('TensorrtExecutionProvider', trt_options)]
)

Model Size Optimization

External Data Format

For large models, store weights externally:
import onnx

# Load model
model = onnx.load("large_model.onnx")

# Save with external data
onnx.save_model(
    model,
    "large_model_external.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.bin",
    size_threshold=1024  # Save tensors > 1KB externally
)

Model Pruning

Remove unnecessary outputs:
import onnx
from onnx import helper

model = onnx.load("model.onnx")

# Keep only specific outputs
outputs_to_keep = ["output1"]
model.graph.ClearField("output")
for output_name in outputs_to_keep:
    for node in model.graph.node:
        if output_name in node.output:
            model.graph.output.append(
                helper.make_tensor_value_info(
                    output_name,
                    onnx.TensorProto.FLOAT,
                    None  # shape unknown; fill in if known
                )
            )

onnx.save(model, "model_pruned.onnx")
Note that this only trims the output list; nodes that fed the removed outputs remain in the graph. onnx.utils.extract_model can extract a subgraph and drop dangling nodes as well.

Batching Strategies

Static Batching

# Process multiple inputs in a single batch
batch_size = 8
input_data = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)

outputs = session.run(None, {"input": input_data})
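When the last batch is not full, fixed-shape models need padding. A small helper sketch (pad_to_batch is an illustrative name, not an ONNX Runtime API) that zero-pads and reports how many rows are valid so callers can discard outputs produced from padding:

```python
import numpy as np

def pad_to_batch(samples, batch_size):
    """Stack samples into a fixed-size batch, zero-padding the remainder.

    Returns (batch, valid) where valid is the number of real rows.
    """
    batch = np.stack(samples)
    valid = len(samples)
    if valid < batch_size:
        pad_shape = (batch_size - valid,) + batch.shape[1:]
        batch = np.concatenate([batch, np.zeros(pad_shape, batch.dtype)])
    return batch, valid

samples = [np.ones((3, 224, 224), dtype=np.float32) for _ in range(5)]
batch, valid = pad_to_batch(samples, batch_size=8)
print(batch.shape, valid)  # (8, 3, 224, 224) 5
```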

Dynamic Batching

class BatchedInference:
    def __init__(self, session, max_batch_size=32, timeout_ms=10):
        self.session = session
        self.max_batch_size = max_batch_size
        self.timeout_ms = timeout_ms
        self.queue = []
    
    def predict(self, input_data):
        self.queue.append(input_data)
        
        if len(self.queue) >= self.max_batch_size:
            return self._run_batch()
        
        # Wait for more requests or timeout
        # Implementation depends on your framework
    
    def _run_batch(self):
        batch = np.stack(self.queue)
        self.queue.clear()
        
        outputs = self.session.run(None, {"input": batch})
        return outputs
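The batching pattern above can be exercised end-to-end with a stub standing in for a real session (StubSession and SimpleBatcher are illustrative names; the stub simply echoes its input so the example stays self-contained):

```python
import numpy as np

class StubSession:
    """Stands in for ort.InferenceSession; returns the input unchanged."""
    def run(self, output_names, feeds):
        return [feeds["input"]]

class SimpleBatcher:
    """Minimal version of the batching pattern: collect, then flush."""
    def __init__(self, session, max_batch_size=4):
        self.session = session
        self.max_batch_size = max_batch_size
        self.queue = []

    def submit(self, sample):
        self.queue.append(sample)
        if len(self.queue) >= self.max_batch_size:
            return self.flush()
        return None  # caller flushes later (e.g., on a timeout)

    def flush(self):
        batch = np.stack(self.queue)
        self.queue.clear()
        return self.session.run(None, {"input": batch})

batcher = SimpleBatcher(StubSession(), max_batch_size=4)
results = [batcher.submit(np.full((2,), i, dtype=np.float32)) for i in range(4)]
# The first three submits return None; the fourth triggers a (4, 2) batch.
print(results[-1][0].shape)  # (4, 2)
```

In production the timeout-driven flush would run on a background thread or event loop; the core collect/flush logic is unchanged.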

Benchmarking

Performance Measurement

import time
import numpy as np

def benchmark_model(session, num_runs=100, warmup_runs=10):
    input_name = session.get_inputs()[0].name
    # Replace symbolic/dynamic dimensions (None or strings) with 1
    input_shape = [
        d if isinstance(d, int) else 1
        for d in session.get_inputs()[0].shape
    ]
    input_data = np.random.randn(*input_shape).astype(np.float32)
    
    # Warmup
    for _ in range(warmup_runs):
        session.run(None, {input_name: input_data})
    
    # Benchmark
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        session.run(None, {input_name: input_data})
        end = time.perf_counter()
        times.append((end - start) * 1000)  # Convert to ms
    
    times = np.array(times)
    print(f"Mean: {times.mean():.2f} ms")
    print(f"Median: {np.median(times):.2f} ms")
    print(f"Std: {times.std():.2f} ms")
    print(f"Min: {times.min():.2f} ms")
    print(f"Max: {times.max():.2f} ms")
    print(f"P95: {np.percentile(times, 95):.2f} ms")
    print(f"P99: {np.percentile(times, 99):.2f} ms")

# Run benchmark
benchmark_model(session)

Optimization Checklist

1. Enable Graph Optimization: set graph_optimization_level to ORT_ENABLE_ALL
2. Choose Execution Provider: use CUDA/TensorRT for NVIDIA GPUs, DirectML for Windows, CoreML for Apple devices
3. Configure Threading: set intra_op_num_threads to the physical core count and inter_op_num_threads to 1-2
4. Enable Memory Optimization: enable cpu_mem_arena and mem_pattern for better memory management
5. Consider Quantization: use dynamic or static quantization for CPU inference
6. Profile Performance: enable profiling to identify bottlenecks
7. Benchmark: measure performance with realistic inputs and compare configurations

Next Steps

Execution Providers

Learn about hardware-specific optimizations

Python API

Return to Python inference guide