MLX Tree Utilities
MLX provides utilities for working with Python trees - arbitrarily nested collections of dictionaries, lists, and tuples. These utilities are essential for working with model parameters, gradients, and complex data structures.

Overview

A Python tree in MLX is any nested combination of:
  • Dictionaries (dict)
  • Lists (list)
  • Tuples (tuple)
  • Leaf values (arrays, scalars, etc.)
Trees must not contain cycles.
Dictionaries in trees should have keys that are valid Python identifiers (e.g., "weight", "bias_1" not "weight-1" or "my key").

Functions

tree_flatten

mlx.utils.tree_flatten(
    tree: Any,
    prefix: str = "",
    is_leaf: Optional[Callable] = None,
    destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]
Flatten a Python tree into a list of (key, value) tuples or a dictionary. Keys use dot notation to represent the tree structure. Parameters:
  • tree (Any): The Python tree to flatten
  • prefix (str): Prefix for keys (first character is discarded). Default: ""
  • is_leaf (callable, optional): Function that returns True for leaf nodes
  • destination (list or dict, optional): Container for flattened tree. Default: None (creates a list)
Returns:
  • List of (key, value) tuples or dictionary with flattened tree
Example:
from mlx.utils import tree_flatten
import mlx.core as mx

# Simple nested structure
tree = {
    "layer1": {
        "weight": mx.array([1, 2, 3]),
        "bias": mx.array([0])
    },
    "layer2": {
        "weight": mx.array([4, 5, 6])
    }
}

flat = tree_flatten(tree)
print(flat)
# [('layer1.weight', array([1, 2, 3])),
#  ('layer1.bias', array([0])),
#  ('layer2.weight', array([4, 5, 6]))]
With lists and tuples:
from mlx.utils import tree_flatten

# Nested list
tree = [[[0]]]
print(tree_flatten(tree))
# [('0.0.0', 0)]

# With prefix
print(tree_flatten([[[0]]], prefix=".model"))
# [('model.0.0.0', 0)]
Flatten to dictionary:
from mlx.utils import tree_flatten

tree = {"a": {"b": 1, "c": 2}}
flat_dict = tree_flatten(tree, destination={})
print(flat_dict)
# {'a.b': 1, 'a.c': 2}
With custom leaf function:
from mlx.utils import tree_flatten
import mlx.core as mx

# Treat arrays as leaves, but flatten dicts
tree = {
    "params": [mx.array([1, 2]), mx.array([3, 4])],
    "state": {"step": 0}
}

flat = tree_flatten(tree, is_leaf=lambda x: isinstance(x, mx.array))
for key, value in flat:
    print(f"{key}: {value}")
# params.0: array([1, 2])
# params.1: array([3, 4])
# state.step: 0

tree_unflatten

mlx.utils.tree_unflatten(
    tree: Union[List[Tuple[str, Any]], Dict[str, Any]]
) -> Any
Recreate a Python tree from its flattened representation. Parameters:
  • tree (list or dict): Flattened tree from tree_flatten()
Returns:
  • Reconstructed Python tree
Example:
from mlx.utils import tree_flatten, tree_unflatten

# Flatten and unflatten
original = {"a": {"b": 1, "c": 2}, "d": [3, 4]}
flat = tree_flatten(original)
reconstructed = tree_unflatten(flat)

print(reconstructed)
# {'a': {'b': 1, 'c': 2}, 'd': [3, 4]}
From dictionary:
from mlx.utils import tree_unflatten

flat_dict = {"hello.world": 42}
tree = tree_unflatten(flat_dict)
print(tree)
# {'hello': {'world': 42}}
Reconstructing model parameters:
from mlx.utils import tree_flatten, tree_unflatten
import mlx.nn as nn
import mlx.core as mx

# Save model parameters
model = nn.Linear(10, 5)
flat_params = tree_flatten(model.parameters())

# Save to file (simplified)
import pickle
with open("params.pkl", "wb") as f:
    pickle.dump(flat_params, f)

# Load and reconstruct
with open("params.pkl", "rb") as f:
    loaded_flat = pickle.load(f)

params = tree_unflatten(loaded_flat)
model.update(params)

tree_map

mlx.utils.tree_map(
    fn: Callable,
    tree: Any,
    *rest: Any,
    is_leaf: Optional[Callable] = None
) -> Any
Apply a function to all leaves of a Python tree. If additional trees are provided, the function is called with corresponding leaves from all trees. Parameters:
  • fn (callable): Function to apply to leaves
  • tree (Any): The main tree to iterate over
  • *rest (Any): Additional trees with matching structure
  • is_leaf (callable, optional): Function to determine leaf nodes
Returns:
  • New tree with function applied to all leaves
Example:
from mlx.utils import tree_map
import mlx.nn as nn
import mlx.core as mx

# Square all parameters
model = nn.Linear(10, 5)
params = model.parameters()
squared_params = tree_map(lambda x: x * x, params)

print(squared_params.keys())
# dict_keys(['weight', 'bias'])
Element-wise operations on multiple trees:
from mlx.utils import tree_map
import mlx.core as mx

# Add corresponding elements from two trees
tree1 = {"a": mx.array([1, 2]), "b": mx.array([3, 4])}
tree2 = {"a": mx.array([5, 6]), "b": mx.array([7, 8])}

result = tree_map(lambda x, y: x + y, tree1, tree2)
print(result)
# {'a': array([6, 8]), 'b': array([10, 12])}
Gradient clipping:
from mlx.utils import tree_map
import mlx.core as mx

def clip_gradients(grads, max_norm=1.0):
    """Clip every gradient element-wise into [-max_norm, max_norm].

    NOTE: despite the parameter name, this clips by *value*, not by
    global norm — each element of each gradient array is clamped
    independently. For norm-based clipping, rescale by the global L2
    norm instead.
    """
    return tree_map(
        lambda g: mx.clip(g, -max_norm, max_norm),
        grads
    )

grads = {
    "layer1": {"weight": mx.array([2.5, -3.0])},
    "layer2": {"weight": mx.array([0.5, 1.5])}
}

clipped = clip_gradients(grads, max_norm=2.0)
print(clipped)
# {'layer1': {'weight': array([2.0, -2.0])},
#  'layer2': {'weight': array([0.5, 1.5])}}
Moving parameters to different dtype:
from mlx.utils import tree_map
import mlx.nn as nn
import mlx.core as mx

model = nn.Linear(100, 50)

# Convert all parameters to float16
fp16_params = tree_map(lambda x: x.astype(mx.float16), model.parameters())
model.update(fp16_params)

tree_map_with_path

mlx.utils.tree_map_with_path(
    fn: Callable,
    tree: Any,
    *rest: Any,
    is_leaf: Optional[Callable] = None,
    path: Optional[Any] = None
) -> Any
Apply a function to tree leaves with their paths. Like tree_map, but the function receives the path as the first argument. Parameters:
  • fn (callable): Function taking (path, *values) as arguments
  • tree (Any): The main tree to iterate over
  • *rest (Any): Additional trees with matching structure
  • is_leaf (callable, optional): Function to determine leaf nodes
  • path (Any, optional): Prefix for paths
Returns:
  • New tree with function applied to all leaves
Example:
from mlx.utils import tree_map_with_path
import mlx.core as mx

tree = {
    "encoder": {
        "layer1": mx.array([1, 2]),
        "layer2": mx.array([3, 4])
    },
    "decoder": {
        "output": mx.array([5, 6])
    }
}

# Print all paths
result = tree_map_with_path(lambda path, x: print(path) or x, tree)
# encoder.layer1
# encoder.layer2
# decoder.output
Selective parameter updates:
from mlx.utils import tree_map_with_path
import mlx.core as mx

def selective_update(path, param, grad, lr=0.01):
    """Apply one SGD step only to parameters under the "encoder" prefix.

    Any parameter whose dotted path does not start with "encoder" is
    returned unchanged, effectively freezing it.
    """
    # Guard clause: non-encoder parameters pass through untouched.
    if not path.startswith("encoder"):
        return param
    return param - lr * grad

params = {
    "encoder": {"weight": mx.array([1.0])},
    "decoder": {"weight": mx.array([2.0])}
}
grads = {
    "encoder": {"weight": mx.array([0.1])},
    "decoder": {"weight": mx.array([0.2])}
}

updated = tree_map_with_path(selective_update, params, grads)
print(updated)
# {'encoder': {'weight': array([0.999])},
#  'decoder': {'weight': array([2.0])}}  # unchanged
Layer-wise learning rates:
from mlx.utils import tree_map_with_path
import mlx.core as mx

def layer_wise_lr(path, param, grad):
    """Take one SGD step using a learning rate chosen from the path."""
    # Known layer names mapped to their rates; first match wins
    # (insertion order preserves the original priority), and anything
    # unrecognized falls back to the smallest rate.
    rates = {"layer1": 0.1, "layer2": 0.01}
    lr = next((r for name, r in rates.items() if name in path), 0.001)
    return param - lr * grad

params = {
    "layer1": {"weight": mx.array([1.0])},
    "layer2": {"weight": mx.array([1.0])},
    "output": {"weight": mx.array([1.0])}
}
grads = tree_map(lambda x: mx.ones_like(x), params)

updated = tree_map_with_path(layer_wise_lr, params, grads)
for path, value in tree_flatten(updated):
    print(f"{path}: {value}")
# layer1.weight: [0.9]      (lr=0.1)
# layer2.weight: [0.99]     (lr=0.01)
# output.weight: [0.999]    (lr=0.001)

tree_reduce

mlx.utils.tree_reduce(
    fn: Callable,
    tree: Any,
    initializer: Any = None,
    is_leaf: Optional[Callable] = None
) -> Any
Reduce a tree to a single value by applying a binary function. Parameters:
  • fn (callable): Binary function taking (accumulator, value)
  • tree (Any): The tree to reduce
  • initializer (Any, optional): Initial accumulator value
  • is_leaf (callable, optional): Function to determine leaf nodes
Returns:
  • Single accumulated value
Example:
from mlx.utils import tree_reduce

# Sum all values in tree
tree = {"a": [1, 2, 3], "b": [4, 5]}
total = tree_reduce(lambda acc, x: acc + x, tree, 0)
print(total)  # 15
Count parameters:
from mlx.utils import tree_reduce
import mlx.nn as nn
import mlx.core as mx

def count_parameters(model):
    """Return the total number of scalar parameters in *model*."""
    def accumulate(running_total, arr):
        # Each leaf is an array; .size is its element count.
        return running_total + arr.size

    return tree_reduce(accumulate, model.parameters(), 0)

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.Linear(256, 10)
)

num_params = count_parameters(model)
print(f"Total parameters: {num_params:,}")  # Total parameters: 203,530
Computing gradient norm:
from mlx.utils import tree_reduce
import mlx.core as mx

def gradient_norm(grads):
    """Return the global L2 norm over every gradient array in *grads*."""
    def add_squares(total, g):
        return total + mx.sum(g * g)

    # Accumulate the squared magnitudes across all leaves, then take
    # a single square root at the end.
    total_sq = tree_reduce(add_squares, grads, mx.array(0.0))
    return mx.sqrt(total_sq)

grads = {
    "layer1": {"weight": mx.array([1.0, 2.0])},
    "layer2": {"weight": mx.array([3.0, 4.0])}
}

norm = gradient_norm(grads)
print(f"Gradient norm: {norm.item():.2f}")  # 5.48
Finding maximum value:
from mlx.utils import tree_reduce
import mlx.core as mx

def tree_max(tree):
    """Return the largest scalar found across all arrays in *tree*."""
    def keep_larger(best_so_far, arr):
        return mx.maximum(best_so_far, mx.max(arr))

    # Seed with -inf so any finite value immediately replaces it.
    return tree_reduce(keep_larger, tree, mx.array(-float('inf')))

params = {
    "layer1": mx.array([1.0, 5.0, 3.0]),
    "layer2": mx.array([2.0, 8.0, 1.0])  # max here
}

max_val = tree_max(params)
print(f"Maximum value: {max_val.item()}")  # 8.0

Practical Examples

Parameter Management

from mlx.utils import tree_map, tree_flatten, tree_unflatten
import mlx.nn as nn
import mlx.core as mx

class ParameterManager:
    """Static helpers for saving, loading, counting, and casting model parameters."""

    @staticmethod
    def save(model, path):
        """Write the model's parameters to *path* as an .npz archive."""
        # Flatten to a {dotted.key: array} dict so each parameter
        # becomes a named entry in the archive.
        flattened = tree_flatten(model.parameters(), destination={})
        mx.savez(path, **flattened)

    @staticmethod
    def load(model, path):
        """Read parameters from *path* and apply them to *model* in place."""
        loaded = dict(mx.load(path))
        model.update(tree_unflatten(loaded))

    @staticmethod
    def count(model):
        """Return the total number of scalar parameters in *model*."""
        from mlx.utils import tree_reduce
        return tree_reduce(
            lambda total, arr: total + arr.size, model.parameters(), 0
        )

    @staticmethod
    def to_dtype(model, dtype):
        """Cast every parameter of *model* to *dtype* in place."""
        casted = tree_map(lambda arr: arr.astype(dtype), model.parameters())
        model.update(casted)

# Usage
model = nn.Linear(100, 50)
ParameterManager.save(model, "model.npz")
print(f"Parameters: {ParameterManager.count(model):,}")
ParameterManager.to_dtype(model, mx.float16)

Optimizer with Tree Utils

from mlx.utils import tree_map
import mlx.core as mx

class SimpleOptimizer:
    """Plain SGD optimizer: param <- param - lr * grad, applied tree-wide."""

    def __init__(self, learning_rate=0.01):
        # Single step size shared by every parameter.
        self.lr = learning_rate

    def update(self, model, gradients):
        """Take one SGD step on *model* using *gradients*.

        *gradients* must have the same tree structure as the model's
        parameters; tree_map pairs corresponding leaves.
        """
        stepped = tree_map(
            lambda param, grad: param - self.lr * grad,
            model.parameters(),
            gradients,
        )
        model.update(stepped)

optimizer = SimpleOptimizer(learning_rate=0.01)
optimizer.update(model, grads)

Gradient Utilities

from mlx.utils import tree_map, tree_reduce
import mlx.core as mx

class GradientUtils:
    """Static helpers for transforming gradient trees."""

    @staticmethod
    def clip_by_norm(grads, max_norm=1.0):
        """Rescale all gradients so their global L2 norm is at most *max_norm*."""
        sq_sum = tree_reduce(
            lambda total, g: total + mx.sum(g * g),
            grads,
            mx.array(0.0),
        )
        global_norm = mx.sqrt(sq_sum)
        # Factor is 1.0 when already within bounds; the epsilon guards
        # against dividing by a zero norm.
        factor = mx.minimum(max_norm / (global_norm + 1e-6), 1.0)
        return tree_map(lambda g: g * factor, grads)

    @staticmethod
    def clip_by_value(grads, min_val=-1.0, max_val=1.0):
        """Clamp every gradient element into [min_val, max_val]."""
        return tree_map(lambda g: mx.clip(g, min_val, max_val), grads)

    @staticmethod
    def add_noise(grads, noise_scale=0.01):
        """Perturb each gradient with scaled Gaussian noise."""
        return tree_map(
            lambda g: g + mx.random.normal(g.shape) * noise_scale,
            grads,
        )

# Usage
grads = compute_gradients(model, batch)
grads = GradientUtils.clip_by_norm(grads, max_norm=1.0)
grads = GradientUtils.add_noise(grads, noise_scale=0.001)
optimizer.update(model, grads)

Tips

  1. Use tree_map for bulk operations: Much faster than manual iteration
  2. Flatten for serialization: tree_flatten makes saving parameters easy
  3. Custom is_leaf for arrays: Treat MLX arrays as leaves in most cases
  4. tree_map_with_path for debugging: Print paths to understand structure
  5. Combine with model.parameters(): All neural network parameters are trees

See Also

mlx.utils API reference — tree_flatten, tree_unflatten, tree_map, tree_map_with_path, tree_reduce