
Overview

torch.export provides APIs for exporting PyTorch models to a graph-based representation that can be deployed, optimized, and serialized. It produces a traced graph representing only the Tensor computation in an Ahead-of-Time (AOT) fashion. The exported graph:
  1. Produces normalized operators in the functional ATen operator set
  2. Eliminates Python control flow and data structures
  3. Records shape constraints for soundness verification

Core Functions

export()

Export a PyTorch model to an ExportedProgram.

Parameters:
mod (torch.nn.Module, required): The module to export.
args (tuple, required): Example positional inputs to the module.
kwargs (dict, optional): Example keyword arguments to the module.
dynamic_shapes (dict | tuple | list, optional): Specification of dynamic dimensions using the Dim API.
strict (bool, default False): If True, traces the module with TorchDynamo, which validates assumptions made about the Python code; if False, uses non-strict tracing via the Python interpreter.
preserve_module_call_signature (tuple, optional): Module paths whose call signatures should be preserved.

Returns: ExportedProgram - The exported program that can be serialized or optimized.
import torch
from torch.export import export

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 5)
    
    def forward(self, x):
        return self.linear(x)

model = MyModel()
example_input = torch.randn(2, 10)

# Export the model
exported_program = export(model, (example_input,))

# Run the exported program via its module() method
output = exported_program.module()(torch.randn(2, 10))
save()

Save an ExportedProgram to a file.

Parameters:
ep (ExportedProgram, required): The exported program to save.
f (str | Path | io.BytesIO, required): File path or file-like object to save to.
from torch.export import export, save

# After exporting
exported_program = export(model, (example_input,))

# Save to file
save(exported_program, "model.pt")
load()

Load an ExportedProgram from a file.

Parameters:
f (str | Path | io.BytesIO, required): File path or file-like object to load from.

Returns: ExportedProgram - The loaded exported program.
from torch.export import load

# Load from file
loaded_program = load("model.pt")

# Use the loaded program
output = loaded_program.module()(torch.randn(2, 10))
unflatten()

Unflatten an exported program back to a module-like structure.

Parameters:
module (ExportedProgram, required): The exported program to unflatten.

Returns: UnflattenedModule - Module with the original structure restored.
from torch.export import export, unflatten

exported_program = export(model, (example_input,))

# Unflatten to restore module structure
unflattened = unflatten(exported_program)

Dynamic Shapes

Dim API

Use the Dim API to specify dynamic dimensions in input tensors.
Create a dynamic dimension specification.

Parameters:
name (str, required): Name of the dynamic dimension.
min (int, optional): Minimum value for the dimension (inclusive).
max (int, optional): Maximum value for the dimension (inclusive).
from torch.export import export, Dim

class Model(torch.nn.Module):
    def forward(self, x):
        return x.sum(dim=1)

model = Model()
x = torch.randn(3, 5)

# Define a dynamic batch dimension (sizes 0 and 1 are specialized
# by the exporter, so prefer min >= 2 for dynamic dims)
batch = Dim("batch", min=2, max=128)

# Export with dynamic shapes
exported = export(
    model,
    (x,),
    dynamic_shapes={"x": {0: batch}}
)

# Can now run with different batch sizes
output1 = exported.module()(torch.randn(8, 5))   # batch=8
output2 = exported.module()(torch.randn(64, 5))  # batch=64
dims()

Create multiple named dynamic dimensions at once.
from torch.export import export, dims

# Create multiple dimensions
batch, seq_len = dims("batch", "seq_len")

# Use in dynamic_shapes
exported = export(
    model,
    (x,),
    dynamic_shapes={"x": {0: batch, 1: seq_len}}
)

ExportedProgram

The ExportedProgram class represents an exported PyTorch model.

Methods and Attributes

module() (callable): Returns a runnable torch.nn.Module built from the exported graph; call it with new inputs.
graph_module (torch.fx.GraphModule): The underlying FX graph module.
graph_signature (ExportGraphSignature): Signature describing inputs, outputs, and state.
module_call_graph (list): Preserved module call hierarchy.

Constraints

Use constraints to specify relationships between dynamic dimensions.
from torch.export import export, Dim

batch = Dim("batch", min=1, max=64)
# Constrain seq_len to be a multiple of 4 by deriving it
# from an underlying dimension
_s = Dim("_s", min=1, max=128)
seq_len = 4 * _s

exported = export(
    model,
    (x, y),
    dynamic_shapes={
        "x": {0: batch, 1: seq_len},
        "y": {0: batch, 1: seq_len}
    }
)

Common Patterns

Export with Multiple Inputs

class MultiInputModel(torch.nn.Module):
    def forward(self, x, y, z=None):
        if z is not None:
            return x + y + z
        return x + y

model = MultiInputModel()

# Export with args and kwargs
exported = export(
    model,
    args=(torch.randn(2, 3), torch.randn(2, 3)),
    kwargs={"z": torch.randn(2, 3)}
)

Export with State Dict

# Export preserves model parameters
exported = export(model, (example_input,))

# State dict is included
print(exported.state_dict.keys())

# Can modify and reload
state = exported.state_dict
state['linear.weight'] = new_weights  # new_weights: a tensor of matching shape

Deployment Workflow

from torch.export import export, save, load

# 1. Training: Export the model
model = train_model()
exported = export(model, (example_input,))

# 2. Save for deployment
save(exported, "model_v1.pt")

# 3. Deployment: Load and run
deployed = load("model_v1.pt")
output = deployed.module()(production_input)

Best Practices

Provide example inputs that represent the full range of shapes and values your model will encounter in production.
For inputs with variable shapes, always specify dynamic dimensions using the Dim API to avoid re-exporting for different sizes.
Validate that the exported program produces the same outputs as the original model with various inputs.
# Compare outputs
original_out = model(test_input)
exported_out = exported_program.module()(test_input)
assert torch.allclose(original_out, exported_out)
torch.export may not support arbitrary Python control flow. Use torch.cond to express data-dependent conditionals in exported code.
# Instead of a Python if on a tensor value
def forward(self, x, condition):
    # torch.cond takes (pred, true_fn, false_fn, operands)
    return torch.cond(condition, lambda x: x + 1, lambda x: x - 1, (x,))

torch.jit

TorchScript JIT compilation (alternative export approach)

torch.compile

Just-in-time compilation for performance

ONNX Export

Export to ONNX format for broader deployment

Quantization

Optimize exported models with quantization
