In this example, we’ll explore how tensor parallelism (TP) works in MLX. We’ll start with an overview of the distributed layers in mlx.nn and then show how to apply tensor parallelism to Llama-style transformer models.

Sharded Layers

MLX provides two types of distributed linear layers that work together to implement tensor parallelism.

AllToShardedLinear

mlx.nn.AllToShardedLinear expects the same (replicated) input on every device and shards the weight matrix along the output dimension across all devices in the group, so each device produces only a slice of the output. For example, with input_dims=2 and output_dims=2, a batched input of shape (4, 2), and 2 devices:
  • Each device receives the full input (4, 2)
  • The weight matrix is split along the output dimension
  • Each device computes a partial output (4, 1)
This is also known as column-wise sharding. The layer does not gather the outputs from each device. This is intentional: the sharded output is exactly the input that ShardedToAllLinear expects, so the two layers can be chained directly.
mlx.nn.QuantizedAllToShardedLinear is the quantized equivalent with frozen parameters.
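The arithmetic behind column-wise sharding can be checked without any devices. The sketch below is a single-process NumPy simulation, not MLX's implementation: it splits a weight of shape (output_dims, input_dims), the layout nn.Linear uses, into one block of rows per simulated device, gives every device the same input, and confirms that concatenating the per-device outputs along the last axis reproduces the unsharded layer. The helper name column_sharded_outputs is hypothetical.

```python
import numpy as np

def column_sharded_outputs(x, weight, bias, n_devices):
    """Simulate all-to-sharded: every device sees the full input x and holds a
    contiguous block of output rows of `weight` plus the matching bias slice."""
    w_shards = np.split(weight, n_devices, axis=0)  # split the output dimension
    b_shards = np.split(bias, n_devices, axis=0)
    # Each simulated device computes only its slice of the output features.
    return [x @ w.T + b for w, b in zip(w_shards, b_shards)]

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 2))       # batch 4, input_dims 2 (replicated)
weight = rng.normal(size=(2, 2))  # (output_dims, input_dims)
bias = rng.normal(size=(2,))

outs = column_sharded_outputs(x, weight, bias, n_devices=2)
print([o.shape for o in outs])  # [(4, 1), (4, 1)]

# Gathering the slices along the feature axis would recover the full output.
assert np.allclose(np.concatenate(outs, axis=-1), x @ weight.T + bias)
```

The final gather is only a correctness check here; in a chained model it is never performed.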

ShardedToAllLinear

mlx.nn.ShardedToAllLinear expects inputs sharded along the feature dimension and shards the weight matrix along the input dimension. The layer automatically aggregates results using mx.distributed.all_sum(), so all devices have the same result. For example, with input_dims=2 and output_dims=2, a batched input of shape (4, 2), and 2 devices:
  • Each device receives a sharded input (4, 1)
  • The weight matrix is split along the input dimension
  • Each device computes a partial output (4, 2)
  • Results are aggregated across devices with all_sum()
This is also known as row-wise sharding. The layer does not shard its inputs automatically; you must provide inputs that are already sharded.
mlx.nn.QuantizedShardedToAllLinear is the quantized equivalent with frozen parameters.
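Row-wise sharding can be simulated the same way. In this NumPy sketch (not MLX's code), each simulated device holds one block of input features and the matching block of weight columns, produces a full-width partial result, and a plain sum stands in for mx.distributed.all_sum(). It also shows one point the math forces: a bias must be added once, after the sum, because adding it on every device before the sum would count it n_devices times. How MLX itself stores the bias is not shown here; the helper name is hypothetical.

```python
import numpy as np

def row_sharded_output(x, weight, bias, n_devices):
    """Simulate sharded-to-all: device i holds the i-th block of the input
    features and the matching block of columns of `weight`."""
    x_shards = np.split(x, n_devices, axis=-1)      # sharded input features
    w_shards = np.split(weight, n_devices, axis=1)  # split the input dimension
    # Every device produces a full-width but partial result ...
    partials = [xs @ ws.T for xs, ws in zip(x_shards, w_shards)]
    # ... and a plain sum over devices models the all_sum().
    summed = np.sum(partials, axis=0)
    # The bias is added once, after the reduction.
    return partials, summed + bias

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 2))
weight = rng.normal(size=(2, 2))
bias = rng.normal(size=(2,))

partials, out = row_sharded_output(x, weight, bias, n_devices=2)
print([p.shape for p in partials])  # [(4, 2), (4, 2)]
assert np.allclose(out, x @ weight.T + bias)
```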

Why These Design Choices?

All-to-sharded and sharded-to-all layers naturally compose because the output of the former is exactly the input needed for the latter. This removes the need for an intermediate gather step, reducing communication overhead. Here’s a simple example showing how they work together:
import mlx.core as mx
import mlx.nn as nn

mx.random.seed(0)  # same seed on every process so x is identical on all devices
x = mx.random.normal((4, 2))  # batch size 4, feature size 2

# First layer: all-to-sharded (column-wise)
l1 = nn.AllToShardedLinear(2, 2, bias=False)
l1_out = l1(x)  # (4, 1) per device when running on 2 devices

# Second layer: sharded-to-all (row-wise)
l2 = nn.ShardedToAllLinear(2, 2, bias=False)
l2_out = l2(l1_out)  # (4, 2), identical on all devices after all_sum()
With 2 devices, the data flows without intermediate gathering:
  • Device 0 computes half of l1_out, feeds it to l2, and participates in all_sum()
  • Device 1 computes the other half of l1_out, feeds it to l2, and participates in all_sum()
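Why no gather is needed can be seen in a small NumPy simulation of l2(l1(x)) (a sketch, not MLX's code, with biases omitted as in bias=False). Both layers split the hidden dimension at the same block boundaries, so device i's slice of l1's output lines up exactly with device i's columns of l2's weight, and a single sum at the end (standing in for all_sum()) gives the unsharded result for any number of devices that divides the hidden size. The function name tp_two_layers is hypothetical.

```python
import numpy as np

def tp_two_layers(x, w1, w2, n_devices):
    """Simulate l2(l1(x)) with l1 all-to-sharded and l2 sharded-to-all.
    w1: (hidden, in) and w2: (out, hidden), both without bias."""
    w1_shards = np.split(w1, n_devices, axis=0)  # l1: split the output dim
    w2_shards = np.split(w2, n_devices, axis=1)  # l2: split the input dim
    partials = []
    for w1_i, w2_i in zip(w1_shards, w2_shards):
        h_i = x @ w1_i.T               # local slice of l1's output, never gathered
        partials.append(h_i @ w2_i.T)  # local partial result of l2
    return np.sum(partials, axis=0)    # the single all_sum() at the end

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
w1 = rng.normal(size=(8, 8))
w2 = rng.normal(size=(8, 8))

reference = x @ w1.T @ w2.T
for n in (1, 2, 4):
    assert np.allclose(tp_two_layers(x, w1, w2, n), reference)
```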

Shard Utility Functions

shard_linear

mlx.nn.layers.distributed.shard_linear() converts a regular linear layer into a tensor parallel layer:
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear

world = mx.distributed.init()

# Convert to all-to-sharded (column-wise)
linear = nn.Linear(256, 256)
sharded = shard_linear(linear, "all-to-sharded", group=world)

# Convert to sharded-to-all (row-wise)
linear2 = nn.Linear(256, 256)
sharded2 = shard_linear(linear2, "sharded-to-all", group=world)
This function creates a new distributed layer and doesn’t modify the original.
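To make the resulting parameter shapes concrete, the sketch below assumes an even, contiguous split by rank (the exact slicing MLX performs is not shown here) and uses NumPy in place of MLX. For nn.Linear(256, 256) on a group of size 2, the all-to-sharded case keeps a block of output rows and the sharded-to-all case keeps a block of input columns. The helper shard_block is hypothetical; like np.split, it rejects dimensions that do not divide evenly.

```python
import numpy as np

def shard_block(weight, axis, rank, size):
    """Hypothetical helper: the contiguous block of `weight` that device `rank`
    (out of `size`) keeps when `weight` is split evenly along `axis`."""
    return np.split(weight, size, axis=axis)[rank]

# nn.Linear(256, 256) stores its weight as (output_dims, input_dims).
weight = np.arange(256 * 256, dtype=np.float32).reshape(256, 256)
size = 2  # simulated group size

for rank in range(size):
    col = shard_block(weight, axis=0, rank=rank, size=size)  # like "all-to-sharded"
    row = shard_block(weight, axis=1, rank=rank, size=size)  # like "sharded-to-all"
    print(rank, col.shape, row.shape)
# 0 (128, 256) (256, 128)
# 1 (128, 256) (256, 128)
```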

shard_inplace

mlx.nn.layers.distributed.shard_inplace() splits parameters across devices by modifying the layer in-place:
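import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_inplace

world = mx.distributed.init()

linear = nn.Linear(256, 256)
shard_inplace(linear, "all-to-sharded", group=world)  # shard along the output dimension
Unlike shard_linear(), this adds no distributed communication: the layer keeps its original forward pass, so any gather or all_sum() must be handled by the layer itself or by the surrounding code.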

LLM Inference with Tensor Parallelism

Let’s apply tensor parallelism to enable inference of larger language models by sharding parameters across multiple devices. We’ll use the Llama inference example as our base.

Initialize Distributed Group

First, initialize the distributed communication group:
import mlx.core as mx

world = mx.distributed.init()
rank = world.rank()

Transformer Architecture

The Llama transformer block has two natural places for tensor parallelism:
  1. Attention block: Q, K, V projections (all-to-sharded) → output projection (sharded-to-all)
  2. FFN block: Gate and up projections (all-to-sharded) → down projection (sharded-to-all)
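The intermediate operations (RoPE, softmax, attention, element-wise multiplication) don't impede tensor parallelism because each device can compute them on its own shard:
  • RoPE acts within a single head, and each device holds whole heads; the multiplication of the gate and up projections in the FFN is element-wise over the hidden dimension, which both projections shard identically
  • Softmax reduces over the key sequence length, which is not sharded, and attention is computed independently per head, so each device handles its own subset of heads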
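The attention argument can be checked numerically. This NumPy sketch is a simplified stand-in, not the Llama example's code: it assumes plain multi-head self-attention with no RoPE, causal mask, or grouped-query attention. Splitting wq, wk, wv by output rows gives each simulated device whole heads, softmax runs over key positions locally, and splitting wo by input columns followed by a sum (standing in for all_sum()) reproduces the unsharded output.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(x, wq, wk, wv, wo, n_heads):
    """Plain multi-head self-attention (no RoPE, mask, or GQA); x is (L, D)."""
    L = x.shape[0]
    q = (x @ wq.T).reshape(L, n_heads, -1).transpose(1, 0, 2)  # (H, L, hd)
    k = (x @ wk.T).reshape(L, n_heads, -1).transpose(1, 0, 2)
    v = (x @ wv.T).reshape(L, n_heads, -1).transpose(1, 0, 2)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])    # (H, L, L)
    out = softmax(scores, axis=-1) @ v                          # softmax over keys
    return out.transpose(1, 0, 2).reshape(L, -1) @ wo.T

def sharded_attention(x, wq, wk, wv, wo, n_heads, n_devices):
    """Each device keeps n_heads // n_devices whole heads; wo is split by input."""
    qs, ks, vs = (np.split(w, n_devices, axis=0) for w in (wq, wk, wv))
    wos = np.split(wo, n_devices, axis=1)
    partials = [
        attention(x, q_i, k_i, v_i, wo_i, n_heads // n_devices)
        for q_i, k_i, v_i, wo_i in zip(qs, ks, vs, wos)
    ]
    return np.sum(partials, axis=0)  # models the all_sum() in the output projection

rng = np.random.default_rng(0)
L, D, H = 5, 16, 4
x = rng.normal(size=(L, D))
wq, wk, wv, wo = (rng.normal(size=(D, D)) for _ in range(4))

full = attention(x, wq, wk, wv, wo, H)
for n in (2, 4):
    assert np.allclose(sharded_attention(x, wq, wk, wv, wo, H, n), full)
```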

Shard the Attention Block

Convert Q, K, V projections to all-to-sharded layers and the output projection to sharded-to-all:
class Attention(nn.Module):
    # ... (initialization code)

    def shard(self, group: mx.distributed.Group):
        # Adjust head counts for sharding; both counts must be divisible by
        # group.size() so that every device holds whole heads
        self.n_heads = self.n_heads // group.size()
        self.n_kv_heads = self.n_kv_heads // group.size()

        # Shard the projections
        self.wq = nn.layers.distributed.shard_linear(
            self.wq, "all-to-sharded", group=group
        )
        self.wk = nn.layers.distributed.shard_linear(
            self.wk, "all-to-sharded", group=group
        )
        self.wv = nn.layers.distributed.shard_linear(
            self.wv, "all-to-sharded", group=group
        )
        self.wo = nn.layers.distributed.shard_linear(
            self.wo, "sharded-to-all", group=group
        )
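Because shard() divides n_heads and n_kv_heads with floor division, it implicitly assumes both are multiples of group.size(); otherwise the sharded projections would not hold whole heads. The hypothetical pre-check below (plain Python, not part of MLX or the Llama example) raises instead, and shows with an illustrative grouped-query configuration of 32 query heads and 8 KV heads that the ratio of query heads to KV heads stays the same on every device. It also implies the group cannot be larger than n_kv_heads.

```python
def per_device_heads(n_heads, n_kv_heads, group_size):
    """Hypothetical pre-check: the head counts each device keeps after shard().
    Raises instead of silently flooring when a count is not divisible."""
    for name, count in (("n_heads", n_heads), ("n_kv_heads", n_kv_heads)):
        if count % group_size != 0:
            raise ValueError(
                f"{name}={count} is not divisible by group size {group_size}"
            )
    return n_heads // group_size, n_kv_heads // group_size

# Illustrative grouped-query config: 32 query heads sharing 8 KV heads.
for size in (1, 2, 4, 8):
    q, kv = per_device_heads(32, 8, size)
    print(size, q, kv, q // kv)  # the query-per-KV ratio stays 4
```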

Shard the FFN Block

Convert the gate (w1) and up (w3) projections to all-to-sharded and the down projection (w2) to sharded-to-all:
class FeedForward(nn.Module):
    # ... (initialization code)

    def shard(self, group: mx.distributed.Group):
        self.w1 = nn.layers.distributed.shard_linear(
            self.w1, "all-to-sharded", group=group
        )
        self.w2 = nn.layers.distributed.shard_linear(
            self.w2, "sharded-to-all", group=group
        )
        self.w3 = nn.layers.distributed.shard_linear(
            self.w3, "all-to-sharded", group=group
        )
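The FFN sharding can be verified the same way. This NumPy sketch assumes the Llama example's feed-forward computes w2(silu(w1(x)) * w3(x)); it is a simulation, not MLX code. Because silu and the product are element-wise over the hidden dimension and w1 and w3 are split at the same boundaries, each simulated device pairs matching hidden units, and summing the partial down projections (standing in for all_sum()) matches the unsharded block.

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

def ffn(x, w1, w2, w3):
    """Llama-style SwiGLU feed-forward: w2(silu(w1 x) * w3 x)."""
    return (silu(x @ w1.T) * (x @ w3.T)) @ w2.T

def sharded_ffn(x, w1, w2, w3, n_devices):
    w1s = np.split(w1, n_devices, axis=0)  # gate: all-to-sharded
    w3s = np.split(w3, n_devices, axis=0)  # up: all-to-sharded
    w2s = np.split(w2, n_devices, axis=1)  # down: sharded-to-all
    # Each device runs the same FFN on its own block of hidden units.
    partials = [ffn(x, a, b, c) for a, b, c in zip(w1s, w2s, w3s)]
    return np.sum(partials, axis=0)        # models the all_sum() in w2

rng = np.random.default_rng(0)
D, hidden = 8, 32
x = rng.normal(size=(3, D))
w1 = rng.normal(size=(hidden, D))
w3 = rng.normal(size=(hidden, D))
w2 = rng.normal(size=(D, hidden))

for n in (1, 2, 4):
    assert np.allclose(sharded_ffn(x, w1, w2, w3, n), ffn(x, w1, w2, w3))
```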

Apply Sharding to All Layers

In the model loading function, apply sharding to all transformer layers when using multiple devices:
def load_model(model_path):
    # ... (model loading code)
    
    world = mx.distributed.init()
    
    if world.size() > 1:
        # Convert Linear layers to appropriate Sharded Layers
        for layer in model.layers:
            layer.attention.shard(group=world)
            layer.feed_forward.shard(group=world)
    
    return model
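Since load_model only shards the attention and feed-forward projections, the embeddings, output projection, and norms stay replicated, which puts a floor under per-device memory. The arithmetic below makes that concrete for an illustrative Llama-style configuration (dim 4096, hidden size 11008, 32 layers, vocabulary 32000, no biases, as many KV heads as query heads); both the configuration and the grouping of parameters are assumptions, and the helper param_counts is hypothetical.

```python
def param_counts(dim, hidden_dim, n_layers, vocab_size, n_devices):
    """Rough (per-device, total) parameter counts for a Llama-style model,
    assuming only the attention and FFN projections are sharded."""
    attn = 4 * dim * dim          # wq, wk, wv, wo (n_kv_heads == n_heads)
    ffn = 3 * dim * hidden_dim    # w1, w2, w3
    sharded = n_layers * (attn + ffn)
    # Token embedding + output projection + two norms per layer + final norm.
    replicated = 2 * vocab_size * dim + (2 * n_layers + 1) * dim
    return sharded // n_devices + replicated, sharded + replicated

for n in (1, 2, 4):
    per_device, total = param_counts(4096, 11008, 32, 32000, n)
    print(n, f"{per_device / 1e9:.2f}B of {total / 1e9:.2f}B parameters per device")
# 1 6.74B of 6.74B parameters per device
# 2 3.50B of 6.74B parameters per device
# 4 1.88B of 6.74B parameters per device
```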

Running with Tensor Parallelism

The same inference script works for both single-device and multi-device execution:
# Single device
python llama.py model.npz tokenizer.model "Your prompt here"

# Two devices with tensor parallelism
mlx.launch -n 2 llama.py model.npz tokenizer.model "Your prompt here"

# Four devices with tensor parallelism
mlx.launch -n 4 llama.py model.npz tokenizer.model "Your prompt here"

Benefits of Tensor Parallelism

  • Larger models: fit models that don't fit in a single device's memory by sharding their parameters across devices.
  • Reduced communication: chaining all-to-sharded and sharded-to-all layers eliminates intermediate gather operations.
  • Same code: the same script works on a single device or on multiple devices without code changes.
  • Efficient inference: enables fast inference of large models on Apple silicon.
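The communication savings can be quantified with simple arithmetic. With the sharding above, each transformer block performs two all_sum() calls per forward pass (one in the attention output projection, one in the FFN down projection), each on a tensor of shape (batch, seq_len, dim). The sketch below counts them for an illustrative 32-layer, dim-4096 model decoding one token at batch size 1, assuming float16 activations; it reports each device's contribution per call, and the bytes actually sent depend on how the backend implements the all-reduce. The helper all_sum_traffic is hypothetical.

```python
def all_sum_traffic(n_layers, dim, batch=1, seq_len=1, bytes_per_elem=2):
    """Per forward pass: number of all_sum() calls and the size in bytes of each
    call's input, assuming two sharded-to-all layers per block and float16."""
    calls = 2 * n_layers
    bytes_per_call = batch * seq_len * dim * bytes_per_elem
    return calls, bytes_per_call

calls, per_call = all_sum_traffic(n_layers=32, dim=4096)
print(calls, per_call, calls * per_call)  # 64 8192 524288
```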

Complete Example

The full Llama inference example with tensor parallelism support is available in mlx-examples.
