MLX enables efficient data parallel distributed training through its distributed communication primitives. In data parallelism, we average gradients across multiple devices before applying them to the model.

Basic Training Loop

Let’s start with a standard MLX training loop before adding distributed training:
model = ...
optimizer = ...
dataset = ...
loss_grad_fn = ...  # e.g. nn.value_and_grad(model, loss_fn)

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())

Adding Gradient Averaging

To convert this to data parallel training, we need to average the gradients across all devices before the optimizer update. We can do this by summing them with mx.distributed.all_sum and dividing by the group size.

Manual Implementation

Define a function to average gradients across all devices:
def all_avg(x):
    return mx.distributed.all_sum(x) / mx.distributed.init().size()
Then apply this function to every gradient in the tree using mlx.utils.tree_map, skipping communication when only one device is present:
from mlx.utils import tree_map

def all_reduce_grads(grads):
    if mx.distributed.init().size() == 1:
        return grads
    return tree_map(all_avg, grads)

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = all_reduce_grads(grads)  # <--- This line was added
    optimizer.update(model, grads)
    return loss
Everything else in the training loop remains the same. Each device computes gradients on its local batch, then all devices synchronize their gradients before updating the model.
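To make the arithmetic concrete, here is a plain-Python sketch of what the sum-then-divide step computes. It uses no MLX at all; the function names mirror the MLX operations but are purely illustrative:

```python
# Plain-Python sketch of gradient averaging (illustrative, no MLX).
# Each "device" holds its own gradient for the same parameter, and
# all_sum adds corresponding entries across devices.

def all_sum(per_device_grads):
    # Element-wise sum across devices, like mx.distributed.all_sum.
    return [sum(vals) for vals in zip(*per_device_grads)]

def all_avg(per_device_grads):
    # Sum across devices, then divide by the number of devices.
    n = len(per_device_grads)
    return [s / n for s in all_sum(per_device_grads)]

grads = [
    [1.0, 2.0, 3.0],  # gradients computed on device 0
    [3.0, 4.0, 5.0],  # gradients computed on device 1
]
print(all_avg(grads))  # [2.0, 3.0, 4.0]
```

After this step every device holds the same averaged gradient, so every model replica takes the same optimizer update.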

Using nn.average_gradients

The manual implementation above works correctly but performs one communication call per gradient tensor, which can be inefficient. MLX provides mlx.nn.average_gradients() to aggregate several gradients together and perform fewer communication steps. The updated code is nearly identical:
import mlx.nn as nn

model = ...
optimizer = ...
dataset = ...

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = nn.average_gradients(grads)  # <--- This line was added
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())
nn.average_gradients() automatically batches communication operations for better performance compared to the manual approach.
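The idea behind this batching can be sketched in plain Python (an illustration of the general technique, not MLX's actual implementation): flatten many small gradients into one buffer, perform a single collective over that buffer, then split the averaged result back into the original shapes.

```python
# Illustrative sketch of batched gradient averaging (not MLX's actual
# implementation): one collective over a flat buffer instead of one per tensor.

def batched_average(per_device_grad_lists):
    # per_device_grad_lists[d] is the list of gradient "tensors" (here,
    # plain lists of floats) computed on device d, in the same order.
    sizes = [len(g) for g in per_device_grad_lists[0]]

    # 1. Each device flattens its gradients into a single buffer.
    flat = [[x for g in grads for x in g] for grads in per_device_grad_lists]

    # 2. One sum over the whole buffer, then divide by the device count.
    n = len(per_device_grad_lists)
    avg = [sum(vals) / n for vals in zip(*flat)]

    # 3. Split the averaged buffer back into the original tensor sizes.
    out, i = [], 0
    for s in sizes:
        out.append(avg[i:i + s])
        i += s
    return out

grads = [
    [[1.0, 2.0], [10.0]],  # device 0: two gradient tensors
    [[3.0, 4.0], [30.0]],  # device 1: the same two tensors
]
print(batched_average(grads))  # [[2.0, 3.0], [20.0]]
```

Collectives have a fixed per-call latency, so replacing many small all-sums with a few large ones is what makes the batched approach faster.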

Complete Example

1. Initialize distributed group: each process joins the distributed group using mx.distributed.init().
2. Split data across devices: each device receives a different subset of the training data.
3. Compute local gradients: each device computes gradients on its local batch.
4. Average gradients: use nn.average_gradients() to synchronize and average gradients.
5. Update model: all devices update their model copy with the averaged gradients.
Here’s a complete training loop with data parallelism:
import mlx.core as mx
import mlx.nn as nn

# Initialize distributed training
world = mx.distributed.init()

# Create model, optimizer, dataset
model = create_model()
optimizer = create_optimizer()
dataset = create_dataset()

# Define loss and gradient function
def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y)

loss_grad_fn = nn.value_and_grad(model, loss_fn)

# Training step with gradient averaging
def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = nn.average_gradients(grads)
    optimizer.update(model, grads)
    return loss

# Training loop
num_epochs = ...
for epoch in range(num_epochs):
    for x, y in dataset:
        loss = step(model, x, y)
        mx.eval(loss, model.parameters())

Key Considerations

- Data sharding: each device should process a different subset of the data to maximize efficiency.
- Synchronization: gradients are synchronized after each batch, keeping every model replica identical.
- Communication overhead: use nn.average_gradients() to batch communications and reduce overhead.
- Scaling: data parallelism scales well because the effective batch size grows with the number of devices.
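One simple way to implement the data sharding described above is rank-based striding: each process keeps every size-th example, offset by its rank. The sketch below uses plain Python lists; in an MLX script, the rank and size would come from the group returned by mx.distributed.init() (via its rank() and size() methods).

```python
# Illustrative rank-based data sharding (plain Python).
# In MLX: world = mx.distributed.init(); shard(data, world.rank(), world.size())

def shard(dataset, rank, size):
    # Each rank takes every `size`-th example starting at `rank`, so
    # every example is processed by exactly one rank per epoch.
    return dataset[rank::size]

dataset = list(range(10))
print(shard(dataset, 0, 2))  # rank 0 of 2: [0, 2, 4, 6, 8]
print(shard(dataset, 1, 2))  # rank 1 of 2: [1, 3, 5, 7, 9]
```

With size == 1 this returns the full dataset, so the same sharding code works unchanged for single-device runs.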

Running Distributed Training

To run your training script across multiple devices, use MLX’s launch utility:
# Single device (no distribution)
python train.py

# Two devices with data parallelism
mlx.launch -n 2 train.py

# Four devices with data parallelism
mlx.launch -n 4 train.py
The same code works for both single-device and multi-device training. When running on a single device, nn.average_gradients() simply returns the gradients unchanged.
