Overview
MLX extensions consist of two main components:
- Operations: Front-end functions that operate on arrays
- Primitives: Building blocks that define computation and transformations
Example: Axpby Operation
We’ll implement a custom operation that computes z = alpha * x + beta * y, combining two scaled arrays.
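To make the target computation concrete, here is a minimal reference sketch in plain C++ over `std::vector<float>`. It is not MLX code: the name `axpby_reference` is hypothetical, and the sketch assumes both inputs already have the same length (no broadcasting, no type promotion).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference axpby on plain vectors (illustrative only, not MLX code).
// Computes z[i] = alpha * x[i] + beta * y[i] for equally sized inputs.
std::vector<float> axpby_reference(
    const std::vector<float>& x,
    const std::vector<float>& y,
    float alpha,
    float beta) {
  assert(x.size() == y.size());  // this sketch does not broadcast
  std::vector<float> z(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    z[i] = alpha * x[i] + beta * y[i];
  }
  return z;
}
```

For example, with x = {1, 2, 3}, y = {4, 5, 6}, alpha = 2 and beta = 0.5, the result is {4, 6.5, 9}.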
Simple Implementation
First, let’s implement it using existing MLX operations:
Creating a Primitive
Define the Primitive Class
A primitive inherits from Primitive and implements evaluation and transformation methods:
axpby.h
Implement the Operation
The operation handles type promotion and broadcasting:
axpby.cpp
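The broadcasting half of this step can be illustrated without MLX. The sketch below is a standalone helper that computes the output shape for two inputs under NumPy-style broadcasting rules. The name `broadcast_shapes` is hypothetical and is not an MLX function; a real operation would rely on the framework's own broadcasting and type-promotion utilities.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// NumPy-style broadcast of two shapes (illustrative helper, not MLX API).
// Shapes are aligned from the trailing dimension; a missing leading
// dimension is treated as size 1, and a size-1 dimension stretches to
// match the other input.
std::vector<int> broadcast_shapes(
    const std::vector<int>& a,
    const std::vector<int>& b) {
  const std::size_t ndim = std::max(a.size(), b.size());
  std::vector<int> out(ndim);
  for (std::size_t i = 0; i < ndim; ++i) {
    const int da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int db = i < b.size() ? b[b.size() - 1 - i] : 1;
    assert(da >= 0 && db >= 0);  // dimensions are never negative
    if (da != db && da != 1 && db != 1) {
      throw std::invalid_argument("shapes cannot be broadcast");
    }
    out[ndim - 1 - i] = (da == 1) ? db : da;
  }
  return out;
}
```

Under these rules, shapes {3, 1} and {4} broadcast to {3, 4}, while {2, 3} and {4} are rejected.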
CPU Implementation
CPU Kernel
Implement the element-wise operation on CPU:
CPU Evaluation
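The CPU path has two parts: a templated element-wise kernel, and an evaluation step that maps a runtime type tag to the matching template instantiation. The following is a minimal standalone sketch of that pattern. `Dtype`, `axpby_kernel` and `axpby_dispatch` are hypothetical names, the buffers are assumed contiguous and equally sized, and only two element types are covered; none of this is MLX's actual interface.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Hypothetical runtime type tag standing in for a framework dtype enum.
enum class Dtype { Float32, Int32 };

// Element-wise kernel over contiguous buffers of element type T.
// The scalars are cast to T first, so for integer types a fractional
// alpha or beta is truncated toward zero.
template <typename T>
void axpby_kernel(const T* x, const T* y, T* out, std::size_t n,
                  float alpha, float beta) {
  const T a = static_cast<T>(alpha);
  const T b = static_cast<T>(beta);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = a * x[i] + b * y[i];
  }
}

// Evaluation step: pick the kernel instantiation from the runtime dtype.
void axpby_dispatch(Dtype dtype, const void* x, const void* y, void* out,
                    std::size_t n, float alpha, float beta) {
  assert(x != nullptr && y != nullptr && out != nullptr);
  switch (dtype) {
    case Dtype::Float32:
      axpby_kernel(static_cast<const float*>(x),
                   static_cast<const float*>(y),
                   static_cast<float*>(out), n, alpha, beta);
      return;
    case Dtype::Int32:
      axpby_kernel(static_cast<const std::int32_t*>(x),
                   static_cast<const std::int32_t*>(y),
                   static_cast<std::int32_t*>(out), n, alpha, beta);
      return;
  }
  throw std::invalid_argument("unsupported dtype");
}
```

Keeping the arithmetic in a template and the type switch in one place means a new element type only needs one extra `case`.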
Dispatch to the correct type:
GPU Implementation
Metal Kernel
Write a Metal kernel for GPU execution:
axpby.metal
GPU Evaluation
Automatic Differentiation
Forward Mode (JVP)
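Because z = alpha * x + beta * y is linear in both inputs, its Jacobian-vector product is dz = alpha * dx + beta * dy, where a missing tangent contributes zero. The sketch below works that out on plain vectors. The function name and the `argnums` convention (0 for x, 1 for y) are assumptions made for illustration, not the signature of MLX's `jvp` method.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Forward-mode sketch for z = alpha * x + beta * y (not MLX code).
// tangents[k] is the tangent of the input selected by argnums[k],
// where 0 means x and 1 means y. The output tangent is the sum of
// each supplied tangent scaled by alpha (for x) or beta (for y).
std::vector<float> axpby_jvp(
    const std::vector<std::vector<float>>& tangents,
    const std::vector<int>& argnums,
    float alpha,
    float beta) {
  assert(!tangents.empty() && tangents.size() == argnums.size());
  std::vector<float> out(tangents[0].size(), 0.0f);
  for (std::size_t k = 0; k < argnums.size(); ++k) {
    assert(argnums[k] == 0 || argnums[k] == 1);
    assert(tangents[k].size() == out.size());
    const float scale = (argnums[k] == 0) ? alpha : beta;
    for (std::size_t i = 0; i < out.size(); ++i) {
      out[i] += scale * tangents[k][i];
    }
  }
  return out;
}
```

When only one input carries a tangent, the result reduces to that tangent multiplied by the corresponding scalar.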
Reverse Mode (VJP)
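In reverse mode, the cotangent g of the output is pulled back to each requested input: the gradient with respect to x is alpha * g, and with respect to y it is beta * g. The following standalone sketch shows that rule. As with the forward-mode case, the function name and the `argnums` convention (0 for x, 1 for y) are illustrative assumptions rather than MLX's `vjp` signature.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reverse-mode sketch for z = alpha * x + beta * y (not MLX code).
// Returns one gradient per entry of argnums, in the same order:
// alpha * cotangent for x (argnum 0), beta * cotangent for y (argnum 1).
std::vector<std::vector<float>> axpby_vjp(
    const std::vector<float>& cotangent,
    const std::vector<int>& argnums,
    float alpha,
    float beta) {
  std::vector<std::vector<float>> grads;
  grads.reserve(argnums.size());
  for (int arg : argnums) {
    assert(arg == 0 || arg == 1);
    const float scale = (arg == 0) ? alpha : beta;
    std::vector<float> g(cotangent.size());
    for (std::size_t i = 0; i < cotangent.size(); ++i) {
      g[i] = scale * cotangent[i];
    }
    grads.push_back(g);
  }
  return grads;
}
```

The gradients do not depend on the values of x or y, which is why neither input needs to be read here.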
Building with CMake
Directory Structure
CMakeLists.txt
Python Bindings
Use nanobind to create Python bindings:
bindings.cpp
setup.py
Building the Extension
Usage
Now you can use your custom operation:
Performance
Custom primitives can significantly improve performance by fusing operations:
Next Steps
- Metal Kernels: Write custom Metal GPU kernels
- C++ Operations: Browse the C++ API reference