Implement distributed data parallel training across multiple devices using MLX distributed primitives
MLX enables efficient data parallel training through its distributed communication primitives. In data parallelism, each device computes gradients on its own shard of the data, and the gradients are averaged across devices before being applied to the model.
To convert this to data parallel training, we need to average the gradients across all devices. We can do this by performing an all_sum and dividing by the group size.
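To make the all-sum-then-divide step concrete, here is a pure-Python sketch (no MLX required) that simulates what the reduction computes; `simulated_all_sum` is a hypothetical stand-in for `mx.distributed.all_sum`:

```python
# Pure-Python sketch of the all-reduce average: all_sum adds the
# per-device tensors element-wise, then each device divides by the
# group size N to obtain the mean gradient.

def simulated_all_sum(per_device_values):
    """Element-wise sum across devices, mimicking mx.distributed.all_sum."""
    return [sum(vals) for vals in zip(*per_device_values)]

# Gradients for the same parameter computed on 3 devices.
per_device_grads = [
    [3.0, -12.0],  # device 0
    [9.0,   0.0],  # device 1
    [6.0,   6.0],  # device 2
]
N = len(per_device_grads)

summed = simulated_all_sum(per_device_grads)
averaged = [g / N for g in summed]
print(averaged)  # [6.0, -2.0]
```

After the reduction, every device holds the same averaged gradient, so all model replicas stay in sync after the optimizer step.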
We define a helper that applies this reduction to every gradient using mlx.utils.tree_map, then call it inside the training step:
```python
import mlx.core as mx
from mlx.utils import tree_map

def all_reduce_grads(grads):
    N = mx.distributed.init().size()
    if N == 1:
        return grads
    return tree_map(
        lambda x: mx.distributed.all_sum(x) / N,
        grads,
    )

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = all_reduce_grads(grads)  # <--- This line was added
    optimizer.update(model, grads)
    return loss
```
Everything else in the training loop remains the same. Each device computes gradients on its local batch, then all devices synchronize their gradients before updating the model.
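The reason this reproduces single-device training can be checked with plain arithmetic: for a mean-reduced loss, the average of the per-shard mean gradients over equal-sized shards equals the gradient over the full batch. A minimal sketch in plain Python (no MLX, hypothetical `grad_w` helper for a toy linear model):

```python
# Why gradient averaging reproduces full-batch training: the mean of
# per-shard mean gradients over equal shards equals the full-batch
# mean gradient.

def grad_w(w, batch):
    """d/dw of mean((w*x - y)^2) over a batch of (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 9.0)]
w = 0.5

# Full-batch gradient, as a single device would compute it.
full = grad_w(w, data)

# Data parallel: two "devices", each holding half the batch.
shard0, shard1 = data[:2], data[2:]
averaged = (grad_w(w, shard0) + grad_w(w, shard1)) / 2

print(full == averaged)  # True
```

Note the shards must be the same size for the plain average to be exact; with uneven shards the per-shard gradients would need to be weighted by shard size.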
The manual implementation above works correctly but performs one communication call per gradient tensor, which can be inefficient. MLX provides mlx.nn.average_gradients() to aggregate several gradients together and perform fewer communication steps. The updated code is nearly identical:
```python
import mlx.core as mx
import mlx.nn as nn

model = ...
optimizer = ...
dataset = ...

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = nn.average_gradients(grads)  # <--- This line was added
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())
```
nn.average_gradients() automatically batches communication operations, which is faster than issuing one call per tensor as the manual approach does.
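The batching idea can be sketched in plain Python. The helper below (`all_reduce_batched` is a hypothetical name, not an MLX API) packs all gradients into one flat buffer, performs a single simulated communication step, and unpacks the result:

```python
# Sketch of batched communication: instead of one all_sum per tensor,
# pack every gradient into a single flat buffer, communicate once,
# and unpack back into named tensors.

def all_reduce_batched(grads_per_device):
    """Average a dict of gradient vectors across devices in one 'call'."""
    N = len(grads_per_device)
    keys = list(grads_per_device[0])
    # Pack: flatten every tensor on each device into one buffer,
    # remembering per-tensor lengths so we can unpack later.
    lengths = [len(grads_per_device[0][k]) for k in keys]
    buffers = [
        [v for k in keys for v in dev[k]] for dev in grads_per_device
    ]
    # One communication step: element-wise average of the flat buffers.
    averaged = [sum(vals) / N for vals in zip(*buffers)]
    # Unpack back into named tensors.
    out, i = {}, 0
    for k, n in zip(keys, lengths):
        out[k] = averaged[i:i + n]
        i += n
    return out

device0 = {"w": [1.0, 3.0], "b": [2.0]}
device1 = {"w": [3.0, 5.0], "b": [4.0]}
print(all_reduce_batched([device0, device1]))
# {'w': [2.0, 4.0], 'b': [3.0]}
```

Fewer, larger communication calls amortize per-call latency, which is why the batched version tends to outperform per-tensor reduction as the number of parameters grows.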
To run your training script across multiple devices, use MLX’s launch utility:
```shell
# Single device (no distribution)
python train.py

# Two devices with data parallelism
mlx.launch -n 2 train.py

# Four devices with data parallelism
mlx.launch -n 4 train.py
```
The same code works for both single-device and multi-device training. When running on a single device, nn.average_gradients() simply returns the gradients unchanged.