MLX provides functions to export a function's computation graph to a file and import it back, plus a helper that writes graphs in DOT format. This is useful for:
  • Deploying models without shipping their Python source
  • Sharing exported functions with other users or programs
  • Debugging computation graphs
  • Understanding how MLX transforms your code

Overview

Exported functions store the traced computation graph together with any constant arrays it uses, such as captured model parameters. The graph can then be imported and executed later without the original Python function.

Functions

export_function

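mlx.core.export_function(
    file: str,
    fun: callable,
    *args,
    shapeless: bool = False,
    **kwargs
) -> None
Export a function to a file. The function is traced with the provided example inputs, and the resulting computation graph is saved. Only the shapes and dtypes of the example inputs matter, not their values. Parameters:
  • file (str): Path to save the exported function (conventionally with a .mlxfn extension)
  • fun (callable): Function to export; it should take and return arrays
  • *args: Example input arrays used to trace the function
  • shapeless (bool): If True, the exported function also accepts inputs with different shapes (default: False)
  • **kwargs: Example keyword-argument arrays; callers of the imported function must pass the same keywords
Example:
import mlx.core as mx

def matmul_and_add(A, B, C):
    return (A @ B) + C

# Example inputs: only their shapes and dtypes are recorded
A = mx.random.normal((100, 100))
B = mx.random.normal((100, 100))
C = mx.random.normal((100, 100))

# Trace and export the function (the file path comes first)
mx.export_function("matmul_add.mlxfn", matmul_and_add, A, B, C)

print("Function exported to matmul_add.mlxfn")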
With function transforms:
import mlx.core as mx

def loss_fn(w, x, y):
    return mx.mean((x @ w - y) ** 2)

w = mx.random.normal((10, 1))
x = mx.random.normal((32, 10))
y = mx.random.normal((32, 1))

# value_and_grad differentiates with respect to the first argument (w);
# inputs to an exported function must be arrays, not nn.Module objects
loss_and_grad = mx.value_and_grad(loss_fn)

# The exported graph returns (loss, grad_w)
mx.export_function("loss_and_grad.mlxfn", loss_and_grad, w, x, y)
The exported function is tied to the shapes and dtypes of the example inputs. Calling the imported function with inputs of a different shape or dtype raises an error. The exception is a function exported with shapeless=True, which relaxes the shape check only.

import_function

mlx.core.import_function(file: str) -> callable
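Import a previously exported MLX function. Returns a callable that runs the imported computation graph. The callable takes the same positional and keyword arguments used at export time. It always returns a tuple of arrays, even if the original function returned a single array. Parameters:
  • file (str): Path to the exported function file
Returns:
  • A callable that runs the imported graph and returns a tuple of arrays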
Example:
import mlx.core as mx

# Import the previously exported function
matmul_and_add = mx.import_function("matmul_add.mlxfn")

# Call it with inputs of the same shapes and dtypes used at export
A = mx.random.normal((100, 100))
B = mx.random.normal((100, 100))
C = mx.random.normal((100, 100))

# Imported functions return a tuple, so unpack the single output
(result,) = matmul_and_add(A, B, C)
print(result.shape)  # (100, 100)
Deployment example:
import mlx.core as mx

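# In production, import the exported function; the model's Python code is not needed
predict = mx.import_function("model_inference.mlxfn")

# data_loader and process_predictions are placeholders for your own code;
# each batch must match the shape and dtype used at export time
for batch in data_loader:
    (predictions,) = predict(batch)
    process_predictions(predictions)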

exporter

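mlx.core.exporter(
    file: str,
    fun: callable,
    *,
    shapeless: bool = False
)
Create an exporter that saves several traces of the same function to one file. Use it as a context manager: each call to the exporter traces fun with that call's arguments and adds the trace to the file. This helps when a function builds different graphs for different inputs, such as different batch sizes or optional arguments. It also avoids duplicating shared constants, such as model parameters, across several files. Parameters:
  • file (str): Path to save the exported traces
  • fun (callable): Function to export
  • shapeless (bool): If True, export traces that accept inputs with varying shapes (default: False)
Example:
import mlx.core as mx

def complex_computation(x, y):
    z = mx.exp(x) + mx.log(y)
    return mx.sin(z) * mx.cos(z)

# Each call inside the block records one trace
with mx.exporter("my_function.mlxfn", complex_computation) as exporter:
    exporter(mx.array([1.0, 2.0, 3.0]), mx.array([4.0, 5.0, 6.0]))
    exporter(mx.array([1.0, 2.0]), mx.array([4.0, 5.0]))

# my_function.mlxfn now holds traces for inputs of shape (3,) and (2,)
With neural network layers:
import mlx.core as mx
import mlx.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [
            nn.Linear(784, 256),
            nn.Linear(256, 10)
        ]

    def __call__(self, x):
        # ReLU on hidden layers only; the last layer returns logits
        for layer in self.layers[:-1]:
            x = nn.relu(layer(x))
        return self.layers[-1](x)

model = MyModel()
mx.eval(model.parameters())  # Store concrete weights, not the init graph

def forward(x):
    return model(x)

# Export traces for a single example and a batch of 32
with mx.exporter("model_forward.mlxfn", forward) as exporter:
    exporter(mx.zeros((1, 784)))
    exporter(mx.zeros((32, 784)))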

export_to_dot

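mlx.core.export_to_dot(
    file: str,
    *args,
    **kwargs
) -> None
Write the computation graph that produces one or more arrays to a file in DOT format. Useful for visualizing and debugging MLX computation graphs. The graph recursively includes all unevaluated inputs of the given outputs. Export it before the outputs are evaluated, for example before printing them. Parameters:
  • file (str): Path of the DOT file to write
  • *args: Output arrays whose graph should be exported
  • **kwargs: Optional named arrays, used to label nodes in the graph
Example:
import mlx.core as mx

def example_fn(x, y):
    a = x + y
    b = a * 2
    return mx.sum(b)

x = mx.array([1.0, 2.0, 3.0])
y = mx.array([4.0, 5.0, 6.0])

# Build the (still lazy) output and write its graph
out = example_fn(x, y)
mx.export_to_dot("example.dot", out)

with open("example.dot") as f:
    print(f.read())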
Visualizing the graph:
import subprocess

import mlx.core as mx

def visualize_graph(fun, *args, output_file="graph.png"):
    # Build the output lazily and write its graph to graph.dot
    mx.export_to_dot("graph.dot", fun(*args))

    # Render with Graphviz (requires the `dot` executable on PATH)
    subprocess.run(
        ["dot", "-Tpng", "graph.dot", "-o", output_file],
        check=True,
    )
    print(f"Graph saved to {output_file}")

# Visualize a small two-layer function
def model_forward(x, w1, w2):
    h = mx.tanh(x @ w1)
    return h @ w2

x = mx.random.normal((10, 5))
w1 = mx.random.normal((5, 8))
w2 = mx.random.normal((8, 3))
mx.eval(x, w1, w2)  # Keep the random-number ops out of the graph

visualize_graph(model_forward, x, w1, w2)
Debugging compilation:
import mlx.core as mx

def fn(x):
    y = x * 2
    z = y + 3
    return mx.sum(z)

optimized_fn = mx.compile(fn)

x = mx.array([1.0, 2.0, 3.0])

# Write the graphs of the uncompiled and compiled versions
mx.export_to_dot("eager.dot", fn(x))
mx.export_to_dot("compiled.dot", optimized_fn(x))

# Compare the two files to see how compilation changed the graph,
# e.g. whether element-wise operations were merged into fewer nodes

Use Cases

Model Deployment

Export a trained model for efficient deployment:
import mlx.core as mx
import mlx.nn as nn

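# Train your model (train_model is a placeholder for your own training code)
model = train_model()

# Evaluate the parameters so the trained weights are stored as constants
mx.eval(model.parameters())

# Export inference function
def inference(x):
    return model(x)

# Export with example input shape
example_input = mx.zeros((1, 224, 224, 3))
mx.export_function("model_inference.mlxfn", inference, example_input)

# In production: input_batch must have shape (1, 224, 224, 3) and dtype float32
infer = mx.import_function("model_inference.mlxfn")
(result,) = infer(input_batch)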

Debugging Performance

Visualize what operations are being performed:
import mlx.core as mx

def slow_function(x):
    # Multiple small operations
    for _ in range(100):
        x = x + 1
    return x

def fast_function(x):
    # A single addition with the same result
    return x + 100

# Compare graphs (write them before the outputs are evaluated)
x = mx.array([1.0])
mx.export_to_dot("slow.dot", slow_function(x))  # a chain of 100 additions
mx.export_to_dot("fast.dot", fast_function(x))  # one addition

Sharing Compiled Functions

Share optimized functions with team members:
import mlx.core as mx

# Data scientist exports a preprocessing function
def preprocess(images):
    images = images / 255.0
    images = images - mx.array([0.485, 0.456, 0.406])
    images = images / mx.array([0.229, 0.224, 0.225])
    return images

example = mx.zeros((1, 224, 224, 3))
mx.export_function("preprocessing.mlxfn", preprocess, example)

# Engineer imports and uses it (raw_batch: float32 array of shape (1, 224, 224, 3))
preprocess = mx.import_function("preprocessing.mlxfn")
(processed_batch,) = preprocess(raw_batch)

Limitations

Important limitations when using export/import:
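  1. Fixed shapes and dtypes: By default, exported functions only accept inputs with the exact shapes and dtypes used during export; passing shapeless=True at export time relaxes the shape requirement, with the same caveats as shapeless compilation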
  2. No dynamic control flow: Functions with Python control flow that depends on array values cannot be exported
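  3. Captured arrays become constants: Arrays that a function captures from an outer scope, such as model parameters, are saved in the file; evaluate them with mx.eval before exporting, or the graph that produces them (for example, random initialization) is exported as well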
  4. Binary format: Exported files are in MLX’s internal binary format and are not human-readable
  5. Version compatibility: Exported functions may not be compatible across different MLX versions

Tips

  1. Consider compilation: If a compiled version of a function is what you want to deploy, export the mx.compile-wrapped function, and benchmark the imported function to confirm any speedup
  2. Document shapes: Keep track of the input shapes used during export
  3. Version control: Store export files with version information
  4. Plan for several batch sizes: Export one trace per batch size into a single file with mx.exporter, or export with shapeless=True
  5. Test imports: Always test that imported functions produce correct results
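One lightweight way to follow the tips on recording shapes and version information is a JSON sidecar next to each exported file. The sketch below uses only the standard library. The `write_manifest` and `verify_manifest` helpers and their fields are hypothetical, and the caller is assumed to pass `mx.__version__` as the version string.

```python
import hashlib
import json
from pathlib import Path

def write_manifest(export_path, mlx_version, input_specs):
    """Hypothetical helper: record what an exported file expects.

    input_specs is a list of (shape, dtype_name) pairs, one per input.
    """
    export_path = Path(export_path)
    manifest = {
        "file": export_path.name,
        "sha256": hashlib.sha256(export_path.read_bytes()).hexdigest(),
        "mlx_version": mlx_version,
        "inputs": [
            {"shape": list(shape), "dtype": dtype} for shape, dtype in input_specs
        ],
    }
    manifest_path = export_path.with_name(export_path.name + ".json")
    manifest_path.write_text(json.dumps(manifest, indent=2))
    return manifest_path

def verify_manifest(export_path):
    """Hypothetical helper: True if the file still matches its recorded hash."""
    export_path = Path(export_path)
    manifest_path = export_path.with_name(export_path.name + ".json")
    manifest = json.loads(manifest_path.read_text())
    digest = hashlib.sha256(export_path.read_bytes()).hexdigest()
    return digest == manifest["sha256"]
```

For example, right after exporting matmul_and_add you might call `write_manifest("matmul_add.mlxfn", mx.__version__, [((100, 100), "float32")] * 3)`. Before deploying, you could then check the file with `verify_manifest`.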
