mlx.nn and then show how to apply tensor parallelism to Llama-style transformer models.
Sharded Layers
MLX provides two types of distributed linear layers that work together to implement tensor parallelism.

AllToShardedLinear
mlx.nn.AllToShardedLinear replicates a common input and shards the weight matrix along the output dimension across all devices in the group. The layer produces a sharded output.
For example, with input_dims=2 and output_dims=2, a batched input of shape (4, 2), and 2 devices:
- Each device receives the full input (4, 2)
- The weight matrix is split along the output dimension
- Each device computes a partial output (4, 1)
mlx.nn.QuantizedAllToShardedLinear is the quantized equivalent with frozen parameters.

ShardedToAllLinear
mlx.nn.ShardedToAllLinear expects inputs sharded along the feature dimension and shards the weight matrix along the input dimension. The layer automatically aggregates results using mx.distributed.all_sum(), so all devices have the same result.
For example, with input_dims=2 and output_dims=2, a batched input of shape (4, 2), and 2 devices:
- Each device receives a sharded input (4, 1)
- The weight matrix is split along the input dimension
- Each device computes a partial output (4, 2)
- Results are aggregated across devices with all_sum()
mlx.nn.QuantizedShardedToAllLinear is the quantized equivalent with frozen parameters.

Why These Design Choices?
All-to-sharded and sharded-to-all layers naturally compose because the output of the former is exactly the input needed by the latter. This removes the need for an intermediate gather step, reducing communication overhead. Here’s a simple example showing how they work together:
- Device 0 computes half of l1_out, feeds it to l2, and participates in all_sum()
- Device 1 computes the other half of l1_out, feeds it to l2, and participates in all_sum()
Shard Utility Functions
shard_linear
mlx.nn.layers.distributed.shard_linear() converts a regular linear layer into a tensor parallel layer:
shard_inplace
mlx.nn.layers.distributed.shard_inplace() splits parameters across devices by modifying the layer in-place:
Unlike shard_linear(), this doesn’t add distributed communication; the layer must handle that itself.
LLM Inference with Tensor Parallelism
Let’s apply tensor parallelism to enable inference of larger language models by sharding parameters across multiple devices. We’ll use the Llama inference example as our base.

Initialize Distributed Group
First, initialize the distributed communication group:
Transformer Architecture
The Llama transformer block has two natural places for tensor parallelism:
- Attention block: Q, K, V projections (all-to-sharded) → output projection (sharded-to-all)
- FFN block: Gate and up projections (all-to-sharded) → down projection (sharded-to-all)
The remaining operations need no changes, because they either:
- Are element-wise operations (RoPE, multiplication) that preserve sharding
- Operate on non-sharded dimensions (softmax on sequence length, attention on head dimensions)
Shard the Attention Block
Convert the Q, K, V projections to all-to-sharded layers and the output projection to sharded-to-all:
Shard the FFN Block
Convert the gate (w1) and up (w3) projections to all-to-sharded and the down projection (w2) to sharded-to-all:
Apply Sharding to All Layers
In the model loading function, apply sharding to all transformer layers when using multiple devices:
Running with Tensor Parallelism
The same inference script works for both single-device and multi-device execution:
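For instance (the script name and host list below are illustrative placeholders):

```shell
# Single device: run the script directly
python llama.py

# Multiple devices: launch one process per host with the MLX launcher
mlx.launch --hosts mac1,mac2 llama.py
```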
Benefits of Tensor Parallelism
- Larger Models: Fit models that don’t fit in a single device’s memory by sharding parameters.
- Reduced Communication: Chaining all-to-sharded and sharded-to-all eliminates intermediate gather operations.
- Same Code: The same script works for single or multiple devices without code changes.
- Efficient Inference: Enables fast inference of large models on Apple silicon.