MLX supports writing custom Metal kernels through both the Python and C++ APIs. This allows you to implement highly optimized GPU operations for Apple Silicon.

Quick Start

Here’s a simple custom kernel that computes exp element-wise:
import mlx.core as mx

source = """
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
"""

kernel = mx.fast.metal_kernel(
    name="myexp",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

def exp_elementwise(a: mx.array):
    outputs = kernel(
        inputs=[a],
        template=[("T", a.dtype)],
        grid=(a.size, 1, 1),
        threadgroup=(256, 1, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )
    return outputs[0]

# Use it
a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
b = exp_elementwise(a)
assert mx.allclose(b, mx.exp(a))

How It Works

Kernel Source

Only pass the body of the Metal kernel in source. The function signature is generated automatically from:
  • Input arrays: input_names and the dtypes of the arrays passed as inputs
  • Output arrays: output_names and output_dtypes
  • Template parameters: the template argument
  • Metal attributes: any Metal attributes (such as thread_position_in_grid) used in source
For the example above, called with a float16 input, the generated code looks roughly like this:
template <typename T>
[[kernel]] void custom_kernel_myexp(
    const device float16_t* inp [[buffer(0)]],
    device float16_t* out [[buffer(1)]],
    uint3 thread_position_in_grid [[thread_position_in_grid]]
) {
    uint elem = thread_position_in_grid.x;
    T tmp = inp[elem];
    out[elem] = metal::exp(tmp);
}

Grid and Threadgroups

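grid and threadgroup are passed to Metal’s dispatchThreads (dispatchThreads:threadsPerThreadgroup: on the compute command encoder), so grid counts threads, not threadgroups:
  • grid: Total number of threads to launch in each of the three dimensions
  • threadgroup: Number of threads per threadgroup in each dimension
A launch therefore runs prod(grid) threads, split into threadgroups of shape threadgroup. For optimal performance, each threadgroup dimension should be ≤ the corresponding grid dimension.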
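The launch arithmetic can be sketched in plain Python. The helpers threadgroups_per_dim and clamp_threadgroup are hypothetical and not part of MLX. They only show how many threadgroups a dispatchThreads-style launch produces and how to keep each threadgroup dimension within the grid. Metal also caps the total threads per threadgroup (the pipeline's maxTotalThreadsPerThreadgroup), and this sketch does not model that limit.

```python
def threadgroups_per_dim(grid, threadgroup):
    """Number of threadgroups per dimension for a dispatchThreads-style launch.

    grid is the total thread count per dimension (not a threadgroup count),
    so each dimension needs ceil(grid / threadgroup) threadgroups; the last
    one may be only partially filled.
    """
    return tuple((g + t - 1) // t for g, t in zip(grid, threadgroup))


def clamp_threadgroup(grid, threadgroup):
    """Hypothetical helper: shrink each threadgroup dimension to at most the
    matching grid dimension, following the guideline above."""
    return tuple(min(g, t) for g, t in zip(grid, threadgroup))


# Quick Start example: a has shape (4, 16), so grid = (64, 1, 1).
grid = (4 * 16, 1, 1)
tg = clamp_threadgroup(grid, (256, 1, 1))
print(tg, threadgroups_per_dim(grid, tg))  # (64, 1, 1) (1, 1, 1)
```

In the Quick Start, the requested threadgroup of 256 threads is larger than the 64-thread grid. Clamping it to 64 follows the guideline above.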

Template Parameters

Template parameters can be:
  • mx.Dtype (mlx.core.Dtype) - Data types (mx.float32, mx.float16, etc.)
  • int - Integer constants
  • bool - Boolean flags
template=[
    ("T", mx.float32),  # Type parameter
    ("N", 256),         # Integer parameter
    ("USE_BIAS", True)  # Boolean parameter
]

Using Shapes and Strides

Row-Contiguous Arrays

By default, ensure_row_contiguous=True copies any input that is not already row-contiguous, so the kernel can use simple linear indexing:
source = """
    uint elem = thread_position_in_grid.x;
    out[elem] = metal::exp(inp[elem]);  // Simple linear indexing
"""

Arbitrary Strides

To avoid copies and support arbitrary strides, set ensure_row_contiguous=False and use MLX indexing utilities:
source = """
    uint elem = thread_position_in_grid.x;
    // elem_to_loc from mlx/backend/metal/kernels/utils.h
    uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
    T tmp = inp[loc];
    out[elem] = metal::exp(tmp);  // Output is always row-contiguous
"""

kernel = mx.fast.metal_kernel(
    name="myexp_strided",
    input_names=["inp"],
    output_names=["out"],
    source=source,
    ensure_row_contiguous=False,
)
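MLX passes {name}_shape, {name}_strides, and {name}_ndim for an input array only when they are referenced in source. Because this source also uses T, call the kernel with template=[("T", a.dtype)] as in the Quick Start.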
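elem_to_loc maps a row-major element index to a memory offset using the input's shape and strides. The pure-Python sketch below mirrors that computation for illustration only. It is not MLX code, and the real Metal helper has its own signature that also takes ndim. For a row-contiguous input the offset equals the index, which is why the simple kernel can read inp[elem] directly.

```python
def row_contiguous_strides(shape):
    """Strides (in elements) of a row-major contiguous array."""
    strides = [1] * len(shape)
    for dim in range(len(shape) - 2, -1, -1):
        strides[dim] = strides[dim + 1] * shape[dim + 1]
    return strides


def elem_to_loc(elem, shape, strides):
    """Map a row-major element index to a strided memory offset."""
    loc = 0
    # Peel off coordinates starting from the last (fastest-varying) dimension.
    for dim in range(len(shape) - 1, -1, -1):
        loc += (elem % shape[dim]) * strides[dim]
        elem //= shape[dim]
    return loc


# A (3, 2) row-contiguous buffer viewed as its transpose: shape (2, 3), strides (1, 2).
print([elem_to_loc(e, (2, 3), (1, 2)) for e in range(6)])  # [0, 2, 4, 1, 3, 5]
```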

Advanced Example: Grid Sample

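Here’s a more complex example implementing bilinear grid sampling on channels-last inputs. x has shape (N, H, W, C), and grid has shape (N, gH, gW, 2) and holds normalized (x, y) sample coordinates in [-1, 1].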

Reference Implementation

First, a reference implementation using standard MLX ops:
import mlx.core as mx

def grid_sample_ref(x, grid):
    N, H_in, W_in, _ = x.shape
    # Map normalized coordinates in [-1, 1] to pixel coordinates
    ix = ((grid[..., 0] + 1) * W_in - 1) / 2
    iy = ((grid[..., 1] + 1) * H_in - 1) / 2

    # Integer coordinates of the four neighboring pixels
    ix_nw = mx.floor(ix).astype(mx.int32)
    iy_nw = mx.floor(iy).astype(mx.int32)

    ix_ne = ix_nw + 1
    iy_ne = iy_nw
    ix_sw = ix_nw
    iy_sw = iy_nw + 1
    ix_se = ix_nw + 1
    iy_se = iy_nw + 1

    # Bilinear weights for each corner
    nw = (ix_se - ix) * (iy_se - iy)
    ne = (ix - ix_sw) * (iy_sw - iy)
    sw = (ix_ne - ix) * (iy - iy_ne)
    se = (ix - ix_nw) * (iy - iy_nw)

    # Gather values from corners
    I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
    I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
    I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
    I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]

    # Zero out corners that fall outside the image
    mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
    mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
    mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
    mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)

    I_nw *= mask_nw[..., None]
    I_ne *= mask_ne[..., None]
    I_sw *= mask_sw[..., None]
    I_se *= mask_se[..., None]

    output = nw[..., None] * I_nw + ne[..., None] * I_ne + \
             sw[..., None] * I_sw + se[..., None] * I_se

    return output
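For a single sample point, the same math reduces to a few scalar operations. The sketch below is a plain-Python restatement of grid_sample_ref for one channel of one image. The function name sample_bilinear and the nested-list image are illustrative and not part of MLX. It shows the mapping from a normalized coordinate to a pixel coordinate and how the four corner weights combine.

```python
import math


def sample_bilinear(img, gx, gy):
    """Bilinearly sample a 2D image (list of rows) at normalized coords (gx, gy).

    Out-of-bounds corners contribute zero, matching the masks in grid_sample_ref.
    """
    H, W = len(img), len(img[0])
    # Normalized [-1, 1] -> pixel coordinates (same formula as the reference).
    ix = ((gx + 1) * W - 1) / 2
    iy = ((gy + 1) * H - 1) / 2
    ix_nw, iy_nw = math.floor(ix), math.floor(iy)

    # (x, y, weight) for the nw, ne, sw, se corners.
    corners = [
        (ix_nw, iy_nw, (ix_nw + 1 - ix) * (iy_nw + 1 - iy)),
        (ix_nw + 1, iy_nw, (ix - ix_nw) * (iy_nw + 1 - iy)),
        (ix_nw, iy_nw + 1, (ix_nw + 1 - ix) * (iy - iy_nw)),
        (ix_nw + 1, iy_nw + 1, (ix - ix_nw) * (iy - iy_nw)),
    ]
    total = 0.0
    for x, y, w in corners:
        if 0 <= y <= H - 1 and 0 <= x <= W - 1:
            total += w * img[y][x]
    return total


img = [[0.0, 1.0], [2.0, 3.0]]
print(sample_bilinear(img, 0.0, 0.0))  # 1.5, the average of all four pixels
```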

Fused Metal Kernel

Now implement the same operation as a single fused Metal kernel, with one thread per output element:
source = """
    uint elem = thread_position_in_grid.x;
    int H = x_shape[1];
    int W = x_shape[2];
    int C = x_shape[3];
    int gH = grid_shape[1];
    int gW = grid_shape[2];
    
    int w_stride = C;
    int h_stride = W * w_stride;
    int b_stride = H * h_stride;
    
    uint grid_idx = elem / C * 2;
    float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
    float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
    
    int ix_nw = floor(ix);
    int iy_nw = floor(iy);
    int ix_ne = ix_nw + 1;
    int iy_ne = iy_nw;
    int ix_sw = ix_nw;
    int iy_sw = iy_nw + 1;
    int ix_se = ix_nw + 1;
    int iy_se = iy_nw + 1;
    
    T nw = (ix_se - ix) * (iy_se - iy);
    T ne = (ix - ix_sw) * (iy_sw - iy);
    T sw = (ix_ne - ix) * (iy - iy_ne);
    T se = (ix - ix_nw) * (iy - iy_nw);
    
    int batch_idx = elem / C / gH / gW * b_stride;
    int channel_idx = elem % C;
    int base_idx = batch_idx + channel_idx;
    
    // Only read corners inside the image; out-of-bounds corners contribute 0
    bool in_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1;
    bool in_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1;
    bool in_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1;
    bool in_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1;

    T I_nw = in_nw ? x[base_idx + iy_nw * h_stride + ix_nw * w_stride] : T(0);
    T I_ne = in_ne ? x[base_idx + iy_ne * h_stride + ix_ne * w_stride] : T(0);
    T I_sw = in_sw ? x[base_idx + iy_sw * h_stride + ix_sw * w_stride] : T(0);
    T I_se = in_se ? x[base_idx + iy_se * h_stride + ix_se * w_stride] : T(0);
    
    out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
"""

kernel = mx.fast.metal_kernel(
    name="grid_sample",
    input_names=["x", "grid"],
    output_names=["out"],
    source=source,
)

import math

@mx.custom_function
def grid_sample(x, grid):
    B, _, _, C = x.shape
    _, gN, gM, D = grid.shape
    assert D == 2, "Last dimension of grid must be 2 (x, y)."
    out_shape = (B, gN, gM, C)

    outputs = kernel(
        inputs=[x, grid],
        template=[("T", x.dtype)],
        output_shapes=[out_shape],
        output_dtypes=[x.dtype],
        grid=(math.prod(out_shape), 1, 1),  # one thread per output element
        threadgroup=(256, 1, 1),
    )
    return outputs[0]

Performance: For x.shape = (8, 1024, 1024, 64) and grid.shape = (8, 256, 256, 2) on an M1 Max:
  • Reference: 55.7ms
  • Fused kernel: 6.7ms
  • Speedup: 8x
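The fused kernel derives every index from the flat thread index elem over the row-contiguous (B, gH, gW, C) output. As a sanity check, the plain-Python sketch below compares the kernel's integer arithmetic against a full unravel of elem. It is illustrative only: unravel and check_indexing are hypothetical helpers, and the sizes are arbitrary.

```python
def unravel(elem, shape):
    """Row-major flat index -> coordinates."""
    coords = []
    for size in reversed(shape):
        coords.append(elem % size)
        elem //= size
    return tuple(reversed(coords))


def check_indexing(B, gH, gW, C):
    for elem in range(B * gH * gW * C):
        b, gy, gx, c = unravel(elem, (B, gH, gW, C))
        # grid has shape (B, gH, gW, 2): x at [..., 0], y at [..., 1].
        expected_grid_idx = ((b * gH + gy) * gW + gx) * 2
        assert elem // C * 2 == expected_grid_idx  # grid_idx in the kernel
        assert elem // C // gH // gW == b          # batch index before * b_stride
        assert elem % C == c                       # channel_idx
    return True


print(check_indexing(B=2, gH=3, gW=4, C=5))  # True
```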

Custom VJP with Atomics

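Implement the backward pass using atomic operations. Different grid points can sample the same pixel of x, so contributions to x_grad are accumulated with atomic adds. The per-location grid gradients are first reduced with simd_sum, then added with a single atomic per simdgroup. Metal has no atomic type for 16-bit floats, so this kernel assumes float32 inputs: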
source = """
    uint elem = thread_position_in_grid.x;
    int H = x_shape[1];
    int W = x_shape[2];
    int C = x_shape[3];
    int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
    
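    // ... compute grid_idx, channel_idx and base_idx (using C_padded in place of C),
    // the bilinear weights, the cotangent value cot, the per-thread terms gix and
    // giy, and gix_mult / giy_mult (elided) ...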
    
    if (channel_idx < C) {
        // Atomically update x_grad
        if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
            int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
            atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
        }
        // ... similar for other corners ...
    }
    
    // Reduce within simdgroup first (faster than pure atomics)
    gix = simd_sum(gix);
    giy = simd_sum(giy);
    
    if (thread_index_in_simdgroup == 0) {
        atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
        atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
    }
"""

kernel = mx.fast.metal_kernel(
    name="grid_sample_grad",
    input_names=["x", "grid", "cotangent"],
    output_names=["x_grad", "grid_grad"],
    source=source,
    atomic_outputs=True,  # Enable atomic operations on outputs
)

@grid_sample.vjp
def grid_sample_vjp(primals, cotangent, _):
    x, grid = primals
    B, _, _, C = x.shape
    _, gN, gM, D = grid.shape
    
    # Pad C to a multiple of the simdgroup size so each simdgroup covers the
    # channels of exactly one grid location and simd_sum never mixes locations
    simdgroup_size = 32
    C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
    grid_size = B * gN * gM * C_padded
    
    outputs = kernel(
        inputs=[x, grid, cotangent],
        template=[("T", x.dtype)],
        output_shapes=[x.shape, grid.shape],
        output_dtypes=[x.dtype, x.dtype],
        grid=(grid_size, 1, 1),
        threadgroup=(256, 1, 1),
        init_value=0,  # Initialize outputs to 0 before kernel
    )
    return outputs[0], outputs[1]

VJP performance: For the same input sizes:
  • Reference: 676.4ms
  • Custom kernel: 16.7ms
  • Speedup: 40x
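The padding in grid_sample_vjp exists because simd_sum adds values across every lane of a simdgroup. The plain-Python sketch below lists which grid locations each simdgroup's threads belong to. It is illustrative only: it assumes a simdgroup width of 32, as the wrapper does, and locations_per_simdgroup is a hypothetical helper. Without padding, one simdgroup can straddle several locations, and simd_sum would mix their gradients. With C rounded up to a multiple of 32, every simdgroup maps to one location, and the extra threads skip the channel_idx < C branch.

```python
SIMD_WIDTH = 32


def ceildiv(a, b):
    return (a + b - 1) // b


def locations_per_simdgroup(num_locations, C, pad):
    """For each simdgroup, the sorted grid locations its threads belong to."""
    c_eff = ceildiv(C, SIMD_WIDTH) * SIMD_WIDTH if pad else C
    total_threads = num_locations * c_eff
    groups = []
    for start in range(0, total_threads, SIMD_WIDTH):
        lanes = range(start, min(start + SIMD_WIDTH, total_threads))
        # Thread elem handles grid location elem // c_eff (cf. grid_idx).
        groups.append(sorted({elem // c_eff for elem in lanes}))
    return groups


print(locations_per_simdgroup(4, 3, pad=False))  # [[0, 1, 2, 3]]
print(locations_per_simdgroup(4, 3, pad=True))   # [[0], [1], [2], [3]]
```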

Kernel Features

Initialization

init_value=0  # Initialize all outputs to this value before kernel runs
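Pass init_value when calling the kernel to fill every output with that value before the kernel runs. Use it whenever the kernel writes only part of an output or accumulates into it, as the atomic adds in the VJP above do. Otherwise, the untouched entries have no guaranteed value.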

Atomic Outputs

atomic_outputs=True  # Make outputs atomic in function signature
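Declares the outputs as Metal atomic types in the generated signature, so threads can update them safely with functions such as atomic_fetch_add_explicit. Metal has no 16-bit float atomics, so atomic float outputs should be float32. See section 6.15 of the Metal Shading Language Specification.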

Verbose Mode

outputs = kernel(
    ...,
    verbose=True  # Print generated Metal code for debugging
)

Metal Attributes

Any attribute from Table 5.8 of the Metal Shading Language Specification can be used in source. Commonly used ones include:
  • thread_position_in_grid - Global thread index
  • thread_position_in_threadgroup - Local thread index
  • thread_index_in_simdgroup - Index within SIMD group
  • threads_per_simdgroup - Size of SIMD group
  • threadgroup_position_in_grid - Threadgroup index
Example:
source = """
    uint gid = thread_position_in_grid.x;
    uint lid = thread_position_in_threadgroup.x;
    uint simd_idx = thread_index_in_simdgroup;
    
    // Use simdgroup operations
    float sum = simd_sum(local_value);
"""

Best Practices

Performance Tips

  1. Fuse operations: Combine multiple operations into one kernel
  2. Use simdgroup operations: simd_sum(), simd_max(), etc. combine values across a simdgroup without threadgroup memory or barriers
  3. Minimize atomics: Use simdgroup reductions first, then atomics
  4. Pad to simdgroup size: When reducing with simd_sum(), pad the launch so that no simdgroup spans two reduction targets
  5. Profile with Xcode: Use Metal GPU capture for detailed profiling

Memory Access

  1. Coalesced reads: Access memory in a pattern that matches thread layout
  2. Bank conflicts: Avoid them when accessing threadgroup memory
  3. Output is contiguous: Output arrays are always row-contiguous

Debugging

  1. Use verbose=True to see generated code
  2. Start with simple kernels and add complexity incrementally
  3. Test against reference implementation
  4. Use Xcode GPU debugger for GPU-side debugging

Utilities

MLX provides utilities in mlx/backend/metal/kernels/utils.h (signatures simplified below):
// Convert linear index to strided location
uint elem_to_loc(uint elem, const int* shape, const int64_t* strides, int ndim);

// Ceiling division
int ceildiv(int a, int b);
The header is included automatically, so these helpers can be called directly from source.

Next Steps

C++ Extensions

Build complete C++ extensions with primitives

Operations Reference

Browse the C++ API reference
