MLX provides utilities for working with Python trees - arbitrarily nested collections of dictionaries, lists, and tuples. These utilities are essential for working with model parameters, gradients, and complex data structures.
Overview
A Python tree in MLX is any nested combination of:
- Dictionaries (dict)
- Lists (list)
- Tuples (tuple)
- Leaf values (arrays, scalars, etc.)
Trees must not contain cycles.
Dictionaries in trees should have keys that are valid Python identifiers (e.g., "weight", "bias_1" not "weight-1" or "my key").
Functions
tree_flatten
mlx.utils.tree_flatten(
tree: Any,
prefix: str = "",
is_leaf: Optional[Callable] = None,
destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]
Flatten a Python tree into a list of (key, value) tuples or a dictionary.
Keys use dot notation to represent the tree structure.
Parameters:
tree (Any): The Python tree to flatten
prefix (str): Prefix for keys (first character is discarded). Default: ""
is_leaf (callable, optional): Function that returns True for leaf nodes
destination (list or dict, optional): Container for flattened tree. Default: None (creates a list)
Returns:
- List of (key, value) tuples or dictionary with flattened tree
Example:
from mlx.utils import tree_flatten
import mlx.core as mx
# Simple nested structure
tree = {
"layer1": {
"weight": mx.array([1, 2, 3]),
"bias": mx.array([0])
},
"layer2": {
"weight": mx.array([4, 5, 6])
}
}
flat = tree_flatten(tree)
print(flat)
# [('layer1.weight', array([1, 2, 3])),
# ('layer1.bias', array([0])),
# ('layer2.weight', array([4, 5, 6]))]
With lists and tuples:
from mlx.utils import tree_flatten
# Nested list
tree = [[[0]]]
print(tree_flatten(tree))
# [('0.0.0', 0)]
# With prefix
print(tree_flatten([[[0]]], prefix=".model"))
# [('model.0.0.0', 0)]
Flatten to dictionary:
from mlx.utils import tree_flatten
tree = {"a": {"b": 1, "c": 2}}
flat_dict = tree_flatten(tree, destination={})
print(flat_dict)
# {'a.b': 1, 'a.c': 2}
With custom leaf function:
from mlx.utils import tree_flatten
import mlx.core as mx
# Treat arrays as leaves, but flatten dicts
tree = {
"params": [mx.array([1, 2]), mx.array([3, 4])],
"state": {"step": 0}
}
flat = tree_flatten(tree, is_leaf=lambda x: isinstance(x, mx.array))
for key, value in flat:
print(f"{key}: {value}")
# params.0: array([1, 2])
# params.1: array([3, 4])
# state.step: 0
tree_unflatten
mlx.utils.tree_unflatten(
tree: Union[List[Tuple[str, Any]], Dict[str, Any]]
) -> Any
Recreate a Python tree from its flattened representation.
Parameters:
tree (list or dict): Flattened tree from tree_flatten()
Returns:
- Reconstructed Python tree
Example:
from mlx.utils import tree_flatten, tree_unflatten
# Flatten and unflatten
original = {"a": {"b": 1, "c": 2}, "d": [3, 4]}
flat = tree_flatten(original)
reconstructed = tree_unflatten(flat)
print(reconstructed)
# {'a': {'b': 1, 'c': 2}, 'd': [3, 4]}
From dictionary:
from mlx.utils import tree_unflatten
flat_dict = {"hello.world": 42}
tree = tree_unflatten(flat_dict)
print(tree)
# {'hello': {'world': 42}}
Reconstructing model parameters:
from mlx.utils import tree_flatten, tree_unflatten
import mlx.nn as nn
import mlx.core as mx
# Save model parameters
model = nn.Linear(10, 5)
flat_params = tree_flatten(model.parameters())
# Save to file (simplified)
import pickle
with open("params.pkl", "wb") as f:
pickle.dump(flat_params, f)
# Load and reconstruct
with open("params.pkl", "rb") as f:
loaded_flat = pickle.load(f)
params = tree_unflatten(loaded_flat)
model.update(params)
tree_map
mlx.utils.tree_map(
fn: Callable,
tree: Any,
*rest: Any,
is_leaf: Optional[Callable] = None
) -> Any
Apply a function to all leaves of a Python tree.
If additional trees are provided, the function is called with corresponding leaves from all trees.
Parameters:
fn (callable): Function to apply to leaves
tree (Any): The main tree to iterate over
*rest (Any): Additional trees with matching structure
is_leaf (callable, optional): Function to determine leaf nodes
Returns:
- New tree with function applied to all leaves
Example:
from mlx.utils import tree_map
import mlx.nn as nn
import mlx.core as mx
# Square all parameters
model = nn.Linear(10, 5)
params = model.parameters()
squared_params = tree_map(lambda x: x * x, params)
print(squared_params.keys())
# dict_keys(['weight', 'bias'])
Element-wise operations on multiple trees:
from mlx.utils import tree_map
import mlx.core as mx
# Add corresponding elements from two trees
tree1 = {"a": mx.array([1, 2]), "b": mx.array([3, 4])}
tree2 = {"a": mx.array([5, 6]), "b": mx.array([7, 8])}
result = tree_map(lambda x, y: x + y, tree1, tree2)
print(result)
# {'a': array([6, 8]), 'b': array([10, 12])}
Gradient clipping:
from mlx.utils import tree_map
import mlx.core as mx
def clip_gradients(grads, max_norm=1.0):
"""Clip all gradients to maximum norm."""
return tree_map(
lambda g: mx.clip(g, -max_norm, max_norm),
grads
)
grads = {
"layer1": {"weight": mx.array([2.5, -3.0])},
"layer2": {"weight": mx.array([0.5, 1.5])}
}
clipped = clip_gradients(grads, max_norm=2.0)
print(clipped)
# {'layer1': {'weight': array([2.0, -2.0])},
# 'layer2': {'weight': array([0.5, 1.5])}}
Moving parameters to different dtype:
from mlx.utils import tree_map
import mlx.nn as nn
import mlx.core as mx
model = nn.Linear(100, 50)
# Convert all parameters to float16
fp16_params = tree_map(lambda x: x.astype(mx.float16), model.parameters())
model.update(fp16_params)
tree_map_with_path
mlx.utils.tree_map_with_path(
fn: Callable,
tree: Any,
*rest: Any,
is_leaf: Optional[Callable] = None,
path: Optional[Any] = None
) -> Any
Apply a function to tree leaves with their paths.
Like tree_map, but the function receives the path as the first argument.
Parameters:
fn (callable): Function taking (path, *values) as arguments
tree (Any): The main tree to iterate over
*rest (Any): Additional trees with matching structure
is_leaf (callable, optional): Function to determine leaf nodes
path (Any, optional): Prefix for paths
Returns:
- New tree with function applied to all leaves
Example:
from mlx.utils import tree_map_with_path
import mlx.core as mx
tree = {
"encoder": {
"layer1": mx.array([1, 2]),
"layer2": mx.array([3, 4])
},
"decoder": {
"output": mx.array([5, 6])
}
}
# Print all paths
result = tree_map_with_path(lambda path, x: print(path) or x, tree)
# encoder.layer1
# encoder.layer2
# decoder.output
Selective parameter updates:
from mlx.utils import tree_map_with_path
import mlx.core as mx
def selective_update(path, param, grad, lr=0.01):
"""Only update encoder parameters."""
if path.startswith("encoder"):
return param - lr * grad
return param
params = {
"encoder": {"weight": mx.array([1.0])},
"decoder": {"weight": mx.array([2.0])}
}
grads = {
"encoder": {"weight": mx.array([0.1])},
"decoder": {"weight": mx.array([0.2])}
}
updated = tree_map_with_path(selective_update, params, grads)
print(updated)
# {'encoder': {'weight': array([0.999])},
# 'decoder': {'weight': array([2.0])}} # unchanged
Layer-wise learning rates:
from mlx.utils import tree_flatten, tree_map, tree_map_with_path
import mlx.core as mx
def layer_wise_lr(path, param, grad):
"""Different learning rate per layer."""
if "layer1" in path:
lr = 0.1
elif "layer2" in path:
lr = 0.01
else:
lr = 0.001
return param - lr * grad
params = {
"layer1": {"weight": mx.array([1.0])},
"layer2": {"weight": mx.array([1.0])},
"output": {"weight": mx.array([1.0])}
}
grads = tree_map(lambda x: mx.ones_like(x), params)
updated = tree_map_with_path(layer_wise_lr, params, grads)
for path, value in tree_flatten(updated):
print(f"{path}: {value}")
# layer1.weight: [0.9] (lr=0.1)
# layer2.weight: [0.99] (lr=0.01)
# output.weight: [0.999] (lr=0.001)
tree_reduce
mlx.utils.tree_reduce(
fn: Callable,
tree: Any,
initializer: Any = None,
is_leaf: Optional[Callable] = None
) -> Any
Reduce a tree to a single value by applying a binary function.
Parameters:
fn (callable): Binary function taking (accumulator, value)
tree (Any): The tree to reduce
initializer (Any, optional): Initial accumulator value
is_leaf (callable, optional): Function to determine leaf nodes
Returns:
- The accumulated value obtained by applying fn across all leaves
Example:
from mlx.utils import tree_reduce
# Sum all values in tree
tree = {"a": [1, 2, 3], "b": [4, 5]}
total = tree_reduce(lambda acc, x: acc + x, tree, 0)
print(total) # 15
Count parameters:
from mlx.utils import tree_reduce
import mlx.nn as nn
import mlx.core as mx
def count_parameters(model):
"""Count total number of parameters."""
params = model.parameters()
return tree_reduce(
lambda acc, x: acc + x.size,
params,
0
)
model = nn.Sequential(
nn.Linear(784, 256),
nn.Linear(256, 10)
)
num_params = count_parameters(model)
print(f"Total parameters: {num_params:,}") # Total parameters: 203,530
Computing gradient norm:
from mlx.utils import tree_reduce
import mlx.core as mx
def gradient_norm(grads):
"""Compute L2 norm of all gradients."""
sum_of_squares = tree_reduce(
lambda acc, g: acc + mx.sum(g * g),
grads,
mx.array(0.0)
)
return mx.sqrt(sum_of_squares)
grads = {
"layer1": {"weight": mx.array([1.0, 2.0])},
"layer2": {"weight": mx.array([3.0, 4.0])}
}
norm = gradient_norm(grads)
print(f"Gradient norm: {norm.item():.2f}") # 5.48
Finding maximum value:
from mlx.utils import tree_reduce
import mlx.core as mx
def tree_max(tree):
"""Find maximum value across all arrays in tree."""
return tree_reduce(
lambda acc, x: mx.maximum(acc, mx.max(x)),
tree,
mx.array(-float('inf'))
)
params = {
"layer1": mx.array([1.0, 5.0, 3.0]),
"layer2": mx.array([2.0, 8.0, 1.0]) # max here
}
max_val = tree_max(params)
print(f"Maximum value: {max_val.item()}") # 8.0
Practical Examples
Parameter Management
from mlx.utils import tree_map, tree_flatten, tree_unflatten
import mlx.nn as nn
import mlx.core as mx
class ParameterManager:
@staticmethod
def save(model, path):
"""Save model parameters to file."""
params = model.parameters()
flat = tree_flatten(params, destination={})
mx.savez(path, **flat)
@staticmethod
def load(model, path):
"""Load model parameters from file."""
flat_dict = dict(mx.load(path))
params = tree_unflatten(flat_dict)
model.update(params)
@staticmethod
def count(model):
"""Count total parameters."""
from mlx.utils import tree_reduce
params = model.parameters()
return tree_reduce(lambda acc, x: acc + x.size, params, 0)
@staticmethod
def to_dtype(model, dtype):
"""Convert all parameters to dtype."""
params = model.parameters()
new_params = tree_map(lambda x: x.astype(dtype), params)
model.update(new_params)
# Usage
model = nn.Linear(100, 50)
ParameterManager.save(model, "model.npz")
print(f"Parameters: {ParameterManager.count(model):,}")
ParameterManager.to_dtype(model, mx.float16)
Optimizer with Tree Utils
from mlx.utils import tree_map
import mlx.core as mx
class SimpleOptimizer:
def __init__(self, learning_rate=0.01):
self.lr = learning_rate
def update(self, model, gradients):
"""Update model parameters using gradients."""
params = model.parameters()
# Update: param = param - lr * grad
new_params = tree_map(
lambda p, g: p - self.lr * g,
params,
gradients
)
model.update(new_params)
optimizer = SimpleOptimizer(learning_rate=0.01)
optimizer.update(model, grads)
Gradient Utilities
from mlx.utils import tree_map, tree_reduce
import mlx.core as mx
class GradientUtils:
@staticmethod
def clip_by_norm(grads, max_norm=1.0):
"""Clip gradients by global norm."""
# Compute global norm
norm = mx.sqrt(tree_reduce(
lambda acc, g: acc + mx.sum(g * g),
grads,
mx.array(0.0)
))
# Clip if necessary
scale = mx.minimum(max_norm / (norm + 1e-6), 1.0)
return tree_map(lambda g: g * scale, grads)
@staticmethod
def clip_by_value(grads, min_val=-1.0, max_val=1.0):
"""Clip each gradient by value."""
return tree_map(
lambda g: mx.clip(g, min_val, max_val),
grads
)
@staticmethod
def add_noise(grads, noise_scale=0.01):
"""Add Gaussian noise to gradients."""
return tree_map(
lambda g: g + mx.random.normal(g.shape) * noise_scale,
grads
)
# Usage
grads = compute_gradients(model, batch)
grads = GradientUtils.clip_by_norm(grads, max_norm=1.0)
grads = GradientUtils.add_noise(grads, noise_scale=0.001)
optimizer.update(model, grads)
Tips
- Use tree_map for bulk operations: Much faster than manual iteration
- Flatten for serialization: tree_flatten makes saving parameters easy
- Custom is_leaf for arrays: Treat MLX arrays as leaves in most cases
- tree_map_with_path for debugging: Print paths to understand structure
- Combine with model.parameters(): All neural network parameters are trees
See Also