MLX provides distributed communication operations that allow computational workloads to be shared across multiple physical machines or GPUs. The distributed module supports several backends including MPI, Ring, JACCL (Thunderbolt RDMA), and NCCL.
Overview
Distributed communication in MLX enables:
- Data parallelism for training large models
- Tensor parallelism for models too large for a single device
- Efficient multi-node inference
- Collective operations like all-reduce and all-gather
Getting Started
A basic distributed program in MLX:
import mlx.core as mx
# Initialize distributed backend
world = mx.distributed.init()
# Sum array across all processes
x = mx.distributed.all_sum(mx.ones(10))
print(f"Rank {world.rank()}: {x}")
Running the program:
# Run on 4 local processes
mlx.launch -n 4 my_script.py
# Run on remote hosts
mlx.launch --hosts host1,host2,host3 my_script.py
Classes
Group
mlx.core.distributed.Group
Represents a group of processes in distributed communication.
Methods:
rank
Returns the rank of the current process in the group.
Example:
import mlx.core as mx
world = mx.distributed.init()
print(f"My rank is: {world.rank()}")
size
Returns the total number of processes in the group.
Example:
import mlx.core as mx
world = mx.distributed.init()
print(f"Total processes: {world.size()}")
Functions
is_available
mlx.core.distributed.is_available(backend: str = "any") -> bool
Check if a distributed backend is available.
Parameters:
backend (str): Backend name. Options: "any", "mpi", "ring", "jaccl", "nccl". Default: "any"
Returns:
True if the backend is available, False otherwise
Example:
import mlx.core as mx
if mx.distributed.is_available("mpi"):
print("MPI backend is available")
if mx.distributed.is_available("jaccl"):
print("JACCL (Thunderbolt RDMA) is available")
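A common pattern is to probe several backends with is_available and pick the fastest one present. The selection itself is just a preference-ordered scan; a minimal pure-Python sketch (the preference order and the helper name are illustrative, not an MLX default):

```python
def pick_backend(available):
    """Return the first available backend in preference order.

    `available` is a set of backend names, e.g. built by calling
    mx.distributed.is_available() for each candidate.
    """
    # Illustrative preference: RDMA/NCCL backends first, TCP ring as fallback
    for backend in ("jaccl", "nccl", "mpi", "ring"):
        if backend in available:
            return backend
    raise RuntimeError("no distributed backend available")

print(pick_backend({"ring", "mpi"}))  # mpi
```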
init
mlx.core.distributed.init(backend: str = "any") -> Group
Initialize the distributed backend and return the world group.
Parameters:
backend (str): Backend to use. Options: "any", "mpi", "ring", "jaccl", "nccl". Default: "any"
Returns:
- Group object representing all processes
Example:
import mlx.core as mx
# Initialize any available backend
world = mx.distributed.init()
# Initialize specific backend
mpi_world = mx.distributed.init(backend="mpi")
jaccl_world = mx.distributed.init(backend="jaccl")
After a distributed backend is successfully initialized, subsequent calls to init() with backend="any" will return the same backend, not initialize a new one.
all_sum
mlx.core.distributed.all_sum(x: array) -> array
Sum the input array across all processes.
Each process receives the total sum of all input arrays.
Parameters:
x (array): Input array to sum
Returns:
- Array containing the sum across all processes
Example:
import mlx.core as mx
world = mx.distributed.init()
# Each process has value equal to its rank
x = mx.full((10,), float(world.rank()))
# Sum across all processes
total = mx.distributed.all_sum(x)
# If world.size() == 4, each process now has [0+1+2+3, ...] = [6, 6, ...]
print(f"Rank {world.rank()}: {total}")
Common use case - Gradient synchronization:
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map
world = mx.distributed.init()
# Compute gradients on the local batch
loss, grads = loss_fn(model, batch)
# Average gradients across all processes (grads may be a nested dict,
# so map over every leaf array)
grads = tree_map(lambda g: mx.distributed.all_sum(g) / world.size(), grads)
# Update model with averaged gradients
optimizer.update(model, grads)
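The effect of the averaging step can be checked without any backend: all_sum gives every rank the same elementwise total, and dividing by the world size turns that total into the mean. A pure-Python simulation of the semantics (illustrative, not MLX code):

```python
def simulated_all_sum(per_rank):
    """Elementwise sum across ranks; every rank receives the same result."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [total[:] for _ in per_rank]  # one copy per rank

# Four ranks, each holding a local gradient vector
grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
size = len(grads)
averaged = [[v / size for v in g] for g in simulated_all_sum(grads)]
print(averaged[0])  # [4.0, 5.0] -- every rank holds the mean gradient
```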
all_gather
mlx.core.distributed.all_gather(x: array) -> array
Gather arrays from all processes.
Each process receives a concatenation of arrays from all processes along the first dimension.
Parameters:
x (array): Input array to gather
Returns:
- Array with first dimension equal to x.shape[0] * world.size()
Example:
import mlx.core as mx
world = mx.distributed.init()
# Each process has different data
x = mx.full((2, 3), float(world.rank()))
print(f"Rank {world.rank()} input shape: {x.shape}") # (2, 3)
# Gather from all processes
gathered = mx.distributed.all_gather(x)
print(f"Rank {world.rank()} output shape: {gathered.shape}") # (8, 3) if 4 processes
# gathered contains:
# [[0, 0, 0],
# [0, 0, 0], # from rank 0
# [1, 1, 1],
# [1, 1, 1], # from rank 1
# [2, 2, 2],
# [2, 2, 2], # from rank 2
# [3, 3, 3],
# [3, 3, 3]] # from rank 3
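The shape rule is easy to verify in plain Python: all_gather concatenates along the first dimension in rank order, so the output's leading dimension is x.shape[0] * world.size(). A backend-free sketch of the semantics (not MLX code):

```python
def simulated_all_gather(per_rank):
    """Concatenate each rank's rows in rank order; every rank gets the result."""
    return [row for rank_data in per_rank for row in rank_data]

# Four ranks, each contributing a (2, 3) block filled with its rank id
per_rank = [[[float(r)] * 3 for _ in range(2)] for r in range(4)]
out = simulated_all_gather(per_rank)
print(len(out), len(out[0]))  # 8 3
print(out[2])  # [1.0, 1.0, 1.0] -- first row from rank 1
```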
send
mlx.core.distributed.send(x: array, dst: int, tag: int = 0)
Send an array to a specific process.
Parameters:
x (array): Array to send
dst (int): Destination rank
tag (int): Message tag for matching with receive. Default: 0
Example:
import mlx.core as mx
world = mx.distributed.init(backend="mpi") # MPI required for send/recv
if world.rank() == 0:
# Rank 0 sends to rank 1
data = mx.array([1.0, 2.0, 3.0])
mx.distributed.send(data, dst=1)
print("Rank 0 sent data")
Point-to-point communication (send and recv) is not supported by the Ring backend. Use MPI, JACCL, or NCCL for these operations.
recv
mlx.core.distributed.recv(
shape: tuple,
dtype: Dtype,
src: int,
tag: int = 0
) -> array
Receive an array from a specific process.
You must know the shape and dtype of the incoming array.
Parameters:
shape (tuple): Shape of the array to receive
dtype (Dtype): Data type of the array to receive
src (int): Source rank
tag (int): Message tag for matching with send. Default: 0
Returns:
- The received array with the requested shape and dtype
Example:
import mlx.core as mx
world = mx.distributed.init(backend="mpi")
if world.rank() == 1:
# Rank 1 receives from rank 0
data = mx.distributed.recv(
shape=(3,),
dtype=mx.float32,
src=0
)
print(f"Rank 1 received: {data}")
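send and recv must be paired: the receiver blocks until a matching message arrives from the source rank. A pure-Python analogy of the pattern using threads and a queue (illustrative only, not MLX code):

```python
import threading
import queue

channel = queue.Queue()  # stands in for the rank-0 -> rank-1 link

def rank0():
    channel.put([1.0, 2.0, 3.0])  # like mx.distributed.send(data, dst=1)

def rank1(result):
    # like mx.distributed.recv(..., src=0): blocks until the message arrives
    result["data"] = channel.get()

result = {}
t0 = threading.Thread(target=rank0)
t1 = threading.Thread(target=rank1, args=(result,))
t1.start(); t0.start()
t0.join(); t1.join()
print(result["data"])  # [1.0, 2.0, 3.0]
```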
recv_like
mlx.core.distributed.recv_like(
x: array,
src: int,
tag: int = 0
) -> array
Receive an array with the same shape and dtype as the template array.
Parameters:
x (array): Template array (only shape and dtype are used)
src (int): Source rank
tag (int): Message tag. Default: 0
Returns:
- The received array with the same shape and dtype as x
Example:
import mlx.core as mx
world = mx.distributed.init(backend="mpi")
if world.rank() == 0:
data = mx.random.normal((10, 20))
mx.distributed.send(data, dst=1)
elif world.rank() == 1:
# Receive with matching shape
template = mx.zeros((10, 20))
data = mx.distributed.recv_like(template, src=0)
print(f"Received shape: {data.shape}")
Backends
Ring Backend
- Always available, no dependencies
- Uses TCP sockets
- Nodes connected in a ring topology
- Best for Ethernet or Thunderbolt connections
- Does not support point-to-point send/recv
world = mx.distributed.init(backend="ring")
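The ring topology gets its bandwidth efficiency from passing partial results around the ring: after size - 1 hops, every node has seen every other node's contribution. A toy pure-Python simulation of a ring all-reduce over scalars (for intuition only; MLX's actual implementation works on array chunks):

```python
def ring_all_reduce(values):
    """Each node starts with one value; after the ring passes, all hold the sum."""
    size = len(values)
    total = list(values)    # each node's running sum
    passing = list(values)  # the value each node is currently forwarding
    for _ in range(size - 1):
        # every node receives from its left neighbor and adds it in
        passing = [passing[(i - 1) % size] for i in range(size)]
        total = [total[i] + passing[i] for i in range(size)]
    return total

print(ring_all_reduce([0, 1, 2, 3]))  # [6, 6, 6, 6]
```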
JACCL Backend
- Low-latency RDMA over Thunderbolt 5 on macOS 26.2+
- Requires fully connected mesh topology
- Order of magnitude lower latency than Ring
- Best for tensor parallelism
world = mx.distributed.init(backend="jaccl")
Setting up JACCL:
# Auto-configure thunderbolt mesh
mlx.distributed_config --verbose \
--hosts m3-1,m3-2,m3-3,m3-4 \
--over thunderbolt --backend jaccl \
--auto-setup --output hostfile.json
# Launch with JACCL
mlx.launch --backend jaccl --hostfile hostfile.json \
--env MLX_METAL_FAST_SYNCH=1 -- \
python my_script.py
MPI Backend
- Full-featured, mature library
- Supports all operations
- Requires MPI installation
world = mx.distributed.init(backend="mpi")
Installation:
# Via conda (recommended)
conda install conda-forge::openmpi
# Launch with MPI
mlx.launch --backend mpi -n 4 my_script.py
NCCL Backend
- High-performance for CUDA environments
- Default backend for CUDA in mlx.launch
- Supports multi-GPU and multi-node
world = mx.distributed.init(backend="nccl")
# Launch on 8 GPUs
mlx.launch -n 8 my_script.py
Practical Examples
Data Parallel Training
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_map
# Initialize distributed
world = mx.distributed.init()
# Each process loads a different shard of the data
train_loader = get_data_loader(
shard_id=world.rank(),
num_shards=world.size()
)
# Training loop
for batch in train_loader:
# Forward pass
loss, grads = loss_fn(model, batch)
# Synchronize gradients across processes
grads = tree_map(
lambda g: mx.distributed.all_sum(g) / world.size(),
grads
)
# Update model
optimizer.update(model, grads)
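A common way to implement the shard_id/num_shards split used above is stride-based indexing, so that each process sees a disjoint subset of the samples. A minimal sketch (get_data_loader itself is assumed by the example, not part of MLX; the helper name here is illustrative):

```python
def shard(samples, shard_id, num_shards):
    """Every process takes every num_shards-th sample, offset by its shard id."""
    return samples[shard_id::num_shards]

data = list(range(10))
print(shard(data, 0, 4))  # [0, 4, 8]
print(shard(data, 1, 4))  # [1, 5, 9]
```

Stride-based sharding keeps the shards roughly equal in size without any coordination between processes.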
Model Parallel Inference
import mlx.core as mx
world = mx.distributed.init(backend="jaccl") # Low latency needed
# Split model across devices
if world.rank() == 0:
output = first_half_of_model(input)
mx.distributed.send(output, dst=1)
elif world.rank() == 1:
# template is assumed to match the shape/dtype of the sent activations
input = mx.distributed.recv_like(template, src=0)
output = second_half_of_model(input)
Distributed Evaluation
import mlx.core as mx
world = mx.distributed.init()
# Each process evaluates different data
local_correct = evaluate_shard(model, test_data[world.rank()])
# Sum correct predictions across all processes
total_correct = mx.distributed.all_sum(mx.array([local_correct]))
if world.rank() == 0:
accuracy = total_correct[0] / total_samples
print(f"Accuracy: {accuracy:.2%}")
Best Practices
- Batch communication: Combine small arrays into larger ones before calling all_sum or all_gather
- Use JACCL for tensor parallelism: The low latency is critical for frequent small communications
- Test locally first: Run with mlx.launch -n 2 on a single machine before scaling up
- Enable fast sync: Set MLX_METAL_FAST_SYNCH=1 when using JACCL
- Overlap computation and communication: Start the next computation before waiting for communication to complete
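The batching advice above usually means flattening many small arrays into one buffer, issuing a single all_sum, then splitting the result back out; the bookkeeping is just offsets. A backend-free sketch of the pack/unpack step (helper names are illustrative, using plain lists in place of arrays):

```python
def pack(arrays):
    """Flatten a list of 1-D arrays (plain lists here) into one buffer."""
    buffer, sizes = [], []
    for a in arrays:
        buffer.extend(a)
        sizes.append(len(a))
    return buffer, sizes

def unpack(buffer, sizes):
    """Split the buffer back into the original array shapes."""
    out, offset = [], 0
    for n in sizes:
        out.append(buffer[offset:offset + n])
        offset += n
    return out

buf, sizes = pack([[1, 2], [3], [4, 5, 6]])
# a single collective call would run on `buf` here instead of three small ones
print(unpack(buf, sizes))  # [[1, 2], [3], [4, 5, 6]]
```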
See Also