mlx-rs
Unofficial Rust bindings for the MLX framework, providing efficient machine learning primitives for Apple Silicon.

Overview
mlx-rs is a Rust wrapper around Apple’s MLX framework, designed to leverage the unified memory architecture and lazy evaluation for high-performance machine learning on Apple Silicon.

Key features
- Lazy evaluation - Operations build compute graphs that execute only when needed
- Unified memory - CPU and GPU share memory, no explicit device transfers
- Type-safe arrays - Strongly-typed n-dimensional arrays with compile-time safety
- Automatic differentiation - Function transforms for gradient computation
- Hardware acceleration - Optimized for Apple Silicon Metal GPUs
Array
The core `Array` type represents n-dimensional tensors.
Construction
- Create an array from a slice with an explicit shape
- Create a scalar array from a boolean value
- Create a scalar array from an i32 value
- Create a scalar array from an f32 value
Array macro
The `array!` macro provides convenient array construction:
Properties
- Returns the shape of the array
- Returns the data type of the array
- Returns the total number of elements in the array
- Returns the number of dimensions
- Checks whether the array is contiguous in memory (row-major/C-style)
Evaluation
- Evaluate the array, forcing computation of lazy operations
- Extract a scalar value from the array (evaluates automatically)
- Access the underlying data as a slice (evaluates automatically)
Operations (ops)
The `ops` module provides array operations.
Factory functions
- Create an array filled with zeros
- Create an array filled with ones
- Create an array with evenly spaced values
- Create a 2D identity matrix
Arithmetic operations
Arrays support standard arithmetic operators.

Reduction operations
- Sum array elements over given axes
- Compute the mean over given axes
- Indices of maximum values along an axis
Shape operations
- Reshape the array to new dimensions
- Permute array dimensions
- Join arrays along an existing axis
Indexing
The `ops::indexing` module provides array indexing operations.
Transforms
Function transformations for automatic differentiation and compilation.

Gradient computation
- Compute the gradient of a function with respect to its first argument
- Compute both the function value and the gradient
Evaluation
- Evaluate multiple arrays in one graph evaluation
- Evaluate all parameters of a module
Random
Random number generation operations.

- Generate an array from a normal distribution
- Generate an array from a uniform distribution
- Sample from a categorical distribution
Neural networks (nn)
The `nn` module provides neural network layers and utilities.
Layers
- `nn::Linear` - Fully connected layer
- `nn::Conv1d`, `nn::Conv2d` - Convolution layers
- `nn::LayerNorm` - Layer normalization
- `nn::Dropout` - Dropout regularization
- `nn::Embedding` - Embedding lookup table
Activations
- `nn::relu` - ReLU activation
- `nn::gelu` - GELU activation
- `nn::silu` - SiLU/Swish activation
- `nn::softmax` - Softmax function
Rotary position embedding
Rotary position embedding for transformers
Module system
The `module` module provides traits for neural network modules.
Trait for neural network modules with learnable parameters
Saving and loading
Arrays and models can be saved and loaded:

| Type | Load | Save |
|---|---|---|
| `Array` | `Array::load_numpy` | `Array::save_numpy` |
| `HashMap<String, Array>` | `Array::load_safetensors` | `Array::save_safetensors` |
| `Module` | `ModuleParametersExt::load_safetensors` | `ModuleParametersExt::save_safetensors` |
Lazy evaluation
MLX uses lazy evaluation - operations build compute graphs without executing.

When to evaluate
A natural place to call `eval()` is at each iteration of an outer loop.
Unified memory
On Apple Silicon, CPU and GPU share unified memory, so no explicit device transfers are needed.

Example: Linear regression
Module structure
- `array` - Core `Array` type and operations
- `ops` - Array operations (arithmetic, reduction, shape)
- `transforms` - Function transforms (grad, eval, compile)
- `random` - Random number generation
- `nn` - Neural network layers and activations
- `module` - Module system for composable networks
- `optimizers` - Optimization algorithms (SGD, Adam, etc.)
- `losses` - Loss functions
- `linalg` - Linear algebra operations
- `fft` - Fast Fourier transform
- `fast` - Hardware-optimized operations
- `dtype` - Data type definitions
- `device` - Device management
- `stream` - Computation stream management