MLX provides functions to export compiled computation graphs to files and import them back. This is useful for:
- Deploying optimized models
- Sharing pre-compiled functions
- Debugging computation graphs
- Understanding how MLX transforms your code
Overview
Exported functions preserve the compiled and optimized computation graph, which can be imported and executed later without recompilation.
Functions
export_function
mlx.core.export_function(
    fun: callable,
    *args,
    file: str
) -> None
Export a compiled MLX function to a file.
The function is traced with the provided arguments and the resulting computation graph is saved.
Parameters:
fun (callable): Function to export
*args: Example arguments used to trace the function
file (str): Path to save the exported function
Example:
import mlx.core as mx

def matmul_and_add(A, B, C):
    return (A @ B) + C

# Trace and export the function
A = mx.random.normal((100, 100))
B = mx.random.normal((100, 100))
C = mx.random.normal((100, 100))

mx.export_function(
    matmul_and_add,
    A, B, C,
    file="matmul_add.mlx"
)
print("Function exported to matmul_add.mlx")
With function transforms:
import mlx.core as mx
import mlx.nn as nn

def loss_fn(model, x, y):
    return mx.mean((model(x) - y) ** 2)

# Export the compiled gradient function
model = create_model()
x = mx.random.normal((32, 10))
y = mx.random.normal((32, 1))

loss_and_grad = mx.value_and_grad(loss_fn)
mx.export_function(
    loss_and_grad,
    model, x, y,
    file="loss_and_grad.mlx"
)
The exported function is tied to the specific input shapes and dtypes used during export; calling the imported function with different shapes raises an error.
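Because shape mismatches only surface at call time, it can help to validate inputs up front. The helper below is a hypothetical sketch in plain Python (`make_checked` is not part of the MLX API): it wraps an imported callable and checks each argument against the shapes and dtypes recorded at export time.

```python
# Hypothetical helper, not part of MLX: validate inputs against the
# shapes and dtypes that were used when the function was exported.

def make_checked(fn, expected_shapes, expected_dtypes):
    """Wrap an imported function so shape/dtype mismatches fail loudly."""
    def checked(*args):
        if len(args) != len(expected_shapes):
            raise ValueError(
                f"expected {len(expected_shapes)} arguments, got {len(args)}"
            )
        for i, (arg, shape, dtype) in enumerate(
            zip(args, expected_shapes, expected_dtypes)
        ):
            if tuple(arg.shape) != tuple(shape):
                raise ValueError(
                    f"argument {i}: expected shape {shape}, "
                    f"got {tuple(arg.shape)}"
                )
            if str(arg.dtype) != str(dtype):
                raise ValueError(
                    f"argument {i}: expected dtype {dtype}, got {arg.dtype}"
                )
        return fn(*args)
    return checked
```

Any object with `.shape` and `.dtype` attributes (such as an MLX array) works with this wrapper, so the check runs before the pre-compiled graph is ever invoked.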
import_function
mlx.core.import_function(file: str) -> callable
Import a previously exported MLX function.
Returns a callable that executes the pre-compiled computation graph.
Parameters:
file (str): Path to the exported function file
Returns:
- Callable function that executes the imported graph
Example:
import mlx.core as mx
# Import the previously exported function
matmul_and_add = mx.import_function("matmul_add.mlx")
# Use it with matching shapes
A = mx.random.normal((100, 100))
B = mx.random.normal((100, 100))
C = mx.random.normal((100, 100))
result = matmul_and_add(A, B, C)
print(result.shape) # (100, 100)
Deployment example:
import mlx.core as mx
# In production, just import the pre-compiled function
predict = mx.import_function("model_inference.mlx")
# Run inference without recompilation
for batch in data_loader:
    predictions = predict(batch)
    process_predictions(predictions)
exporter
@mlx.core.exporter(file: str)
def function(*args):
    ...
Decorator to automatically export a function when it’s first called.
The function is traced with the first set of arguments it receives and exported to the specified file.
Parameters:
file (str): Path to save the exported function
Example:
import mlx.core as mx
@mx.exporter("my_function.mlx")
def complex_computation(x, y):
    z = mx.exp(x) + mx.log(y)
    return mx.sin(z) * mx.cos(z)
# First call traces and exports
x = mx.array([1.0, 2.0, 3.0])
y = mx.array([4.0, 5.0, 6.0])
result = complex_computation(x, y)
# Function is now exported to my_function.mlx
With neural network layers:
import mlx.core as mx
import mlx.nn as nn
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [
            nn.Linear(784, 256),
            nn.Linear(256, 10)
        ]

    @mx.exporter("model_forward.mlx")
    def __call__(self, x):
        for layer in self.layers:
            x = nn.relu(layer(x))
        return x
model = MyModel()
x = mx.random.normal((32, 784))
output = model(x) # Exports on first call
export_to_dot
mlx.core.export_to_dot(
    fun: callable,
    *args
) -> str
Export the computation graph as a DOT format string.
Useful for visualizing and debugging MLX computation graphs.
Parameters:
fun (callable): Function to trace
*args: Arguments used to trace the function
Returns:
- String containing the graph in DOT format
Example:
import mlx.core as mx
def example_fn(x, y):
    a = x + y
    b = a * 2
    return mx.sum(b)
x = mx.array([1.0, 2.0, 3.0])
y = mx.array([4.0, 5.0, 6.0])
# Get DOT representation
dot_string = mx.export_to_dot(example_fn, x, y)
print(dot_string)
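For larger graphs, printing the raw DOT text is noisy. A small, MLX-independent helper can summarize it instead. This is a rough sketch that assumes standard DOT conventions (one `->` per edge statement, nodes declared with attributes or a trailing semicolon), not any MLX-specific output format:

```python
# Hypothetical helper: rough node/edge counts for a DOT graph string.
# Assumes standard DOT syntax, not any MLX-specific output format.

def summarize_dot(dot_string):
    nodes = 0
    edges = 0
    for line in dot_string.splitlines():
        line = line.strip()
        if not line or line.startswith(("digraph", "graph", "{", "}", "//")):
            continue
        if "->" in line:
            edges += 1
        elif line.endswith(";") or "[" in line:
            nodes += 1
    return {"nodes": nodes, "edges": edges}

example = """digraph {
  a [label="add"];
  b [label="mul"];
  a -> b;
}"""
print(summarize_dot(example))  # {'nodes': 2, 'edges': 1}
```

Comparing the counts before and after wrapping a function with @mx.compile gives a quick, if coarse, measure of how much the graph was simplified.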
Visualizing the graph:
import mlx.core as mx
import subprocess
def visualize_graph(fun, *args, output_file="graph.png"):
    dot_string = mx.export_to_dot(fun, *args)
    # Save to file and render with graphviz
    with open("graph.dot", "w") as f:
        f.write(dot_string)
    subprocess.run([
        "dot", "-Tpng", "graph.dot",
        "-o", output_file
    ])
    print(f"Graph saved to {output_file}")
# Visualize a complex function
def model_forward(x, w1, w2):
    h = mx.tanh(x @ w1)
    return h @ w2
x = mx.random.normal((10, 5))
w1 = mx.random.normal((5, 8))
w2 = mx.random.normal((8, 3))
visualize_graph(model_forward, x, w1, w2)
Debugging compilation:
import mlx.core as mx
@mx.compile
def optimized_fn(x):
    y = x * 2
    z = y + 3
    return mx.sum(z)
# See what the compiled graph looks like
x = mx.array([1.0, 2.0, 3.0])
dot = mx.export_to_dot(optimized_fn, x)
# Check for fusion opportunities
if "fused" in dot.lower():
    print("Operations were fused!")
else:
    print("No fusion detected")
Use Cases
Model Deployment
Export a trained model for efficient deployment:
import mlx.core as mx
import mlx.nn as nn
# Train your model
model = train_model()
# Export inference function
def inference(x):
    return model(x)

# Export with example input shape
example_input = mx.zeros((1, 224, 224, 3))
mx.export_function(
    inference,
    example_input,
    file="model_inference.mlx"
)
# In production
infer = mx.import_function("model_inference.mlx")
result = infer(input_batch)
Performance Debugging
Visualize which operations are being performed:
import mlx.core as mx
def slow_function(x):
    # Multiple small operations
    for _ in range(100):
        x = x + 1
    return x

def fast_function(x):
    # Single equivalent operation
    return x + 100
# Compare graphs
x = mx.array([1.0])
print("Slow version:")
print(mx.export_to_dot(slow_function, x))
print("\nFast version:")
print(mx.export_to_dot(fast_function, x))
Sharing Compiled Functions
Share optimized functions with team members:
import mlx.core as mx
# Data scientist exports optimized preprocessing
def preprocess(images):
    images = images / 255.0
    images = images - mx.array([0.485, 0.456, 0.406])
    images = images / mx.array([0.229, 0.224, 0.225])
    return images

example = mx.zeros((1, 224, 224, 3))
mx.export_function(
    preprocess,
    example,
    file="preprocessing.mlx"
)
# Engineer imports and uses it
preprocess = mx.import_function("preprocessing.mlx")
processed_batch = preprocess(raw_batch)
Limitations
Important limitations when using export/import:
- Fixed shapes: Exported functions only work with the exact input shapes used during export
- No dynamic control flow: Functions with Python control flow that depends on array values cannot be exported
- No closures: Functions that capture variables from outer scopes may not export correctly
- Binary format: Exported files are in MLX’s internal binary format and are not human-readable
- Version compatibility: Exported functions may not be compatible across different MLX versions
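Of these, the closure limitation is the easiest to design around: pass everything the function depends on as an explicit argument. The sketch below is plain Python (no MLX required) illustrating the pattern; the same shape applies to any function you hand to export_function.

```python
# Plain-Python illustration of the "no closures" limitation.

scale = 2.0

def risky(x):
    # Captures `scale` from the enclosing scope. A tracer only sees
    # `x`, so the captured value may be baked in (or lost) at export.
    return x * scale

def exportable(x, scale):
    # Everything the function depends on is an explicit argument,
    # so it becomes part of the traced graph's inputs.
    return x * scale
```

Both functions compute the same result in ordinary Python, but only the explicit-argument form gives the tracer a complete picture of the inputs.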
Tips
- Export after optimization: Always export after applying @mx.compile for maximum performance
- Document shapes: Keep track of the input shapes used during export
- Version control: Store export files with version information
- Batch multiple inputs: Export functions that handle batched inputs for flexibility
- Test imports: Always test that imported functions produce correct results
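The shape-documentation and versioning tips can be combined by writing a JSON sidecar next to each exported file. This is a hypothetical convention, not something MLX does for you; the write_sidecar helper and its field names are made up for illustration.

```python
import json

# Hypothetical convention: record export metadata in a JSON sidecar
# next to the .mlx file so consumers know the expected call signature.

def write_sidecar(export_path, shapes, dtypes, mlx_version):
    meta = {
        "export_file": export_path,
        "input_shapes": [list(s) for s in shapes],
        "input_dtypes": [str(d) for d in dtypes],
        "mlx_version": mlx_version,
    }
    sidecar = export_path + ".json"
    with open(sidecar, "w") as f:
        json.dump(meta, f, indent=2)
    return sidecar
```

On the consuming side, reading the sidecar before calling import_function lets you fail fast on a shape or version mismatch instead of debugging it at call time.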
See Also