Skip to main content

mlx-rs

Unofficial Rust bindings for the MLX framework, providing efficient machine learning primitives for Apple Silicon.

Overview

mlx-rs is a Rust wrapper around Apple’s MLX framework, designed to leverage unified memory architecture and lazy evaluation for high-performance machine learning on Apple Silicon.

Key features

  • Lazy evaluation - Operations build compute graphs that execute only when needed
  • Unified memory - CPU and GPU share memory, no explicit device transfers
  • Type-safe arrays - Strongly-typed n-dimensional arrays with compile-time safety
  • Automatic differentiation - Function transforms for gradient computation
  • Hardware acceleration - Optimized for Apple Silicon Metal GPUs

Array

The core Array type represents n-dimensional tensors.

Construction

Array::from_slice
fn
Create array from slice with explicit shape
use mlx_rs::Array;

let data = vec![1i32, 2, 3, 4, 5, 6];
let arr = Array::from_slice(&data, &[2, 3]);
Array::from_bool
fn
Create scalar array from boolean value
let arr = Array::from_bool(true);
Array::from_int
fn
Create scalar array from i32 value
let arr = Array::from_int(42);
Array::from_f32
fn
Create scalar array from f32 value
let arr = Array::from_f32(3.14);

Array macro

The array! macro provides convenient array construction:
use mlx_rs::{array, Dtype};

let a = array!([1, 2, 3, 4]);
assert_eq!(a.shape(), &[4]);
assert_eq!(a.dtype(), Dtype::Int32);

let b = array!([1.0, 2.0, 3.0, 4.0]);
assert_eq!(b.dtype(), Dtype::Float32);

Properties

shape
&[i32]
Returns the shape of the array
let arr = array!([[1, 2], [3, 4]]);
assert_eq!(arr.shape(), &[2, 2]);
dtype
Dtype
Returns the data type of the array
let arr = array!([1.0, 2.0]);
assert_eq!(arr.dtype(), Dtype::Float32);
size
usize
Returns total number of elements in the array
let arr = array!([[1, 2], [3, 4]]);
assert_eq!(arr.size(), 4);
ndim
usize
Returns number of dimensions
let arr = array!([[1, 2], [3, 4]]);
assert_eq!(arr.ndim(), 2);
is_contiguous
bool
Check if array is contiguous in memory (row-major/C-style)
use mlx_rs::ops::indexing::IndexOp;

let arr = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
assert!(arr.is_contiguous());

let sliced = arr.index((.., ..2));
// May not be contiguous after indexing

Evaluation

Array::eval
fn() -> Result<()>
Evaluate the array, forcing computation of lazy operations
use mlx_rs::{array, transforms::eval};

let a = array!([1, 2, 3, 4]);
let b = array!([1.0, 2.0, 3.0, 4.0]);

let c = &a + &b; // Not evaluated yet
c.eval().unwrap(); // Evaluate now
Array::item
fn<T>() -> T
Extract scalar value from array (evaluates automatically)
let arr = array!(42);
let value: i32 = arr.item();
assert_eq!(value, 42);
Array::as_slice
fn<T>() -> &[T]
Access underlying data as slice (evaluates automatically)
let arr = array!([1, 2, 3, 4]);
let slice: &[i32] = arr.as_slice();
assert_eq!(slice, &[1, 2, 3, 4]);

Operations (ops)

The ops module provides array operations.

Factory functions

ops::zeros
fn<T>(shape: &[i32]) -> Result<Array>
Create array filled with zeros
use mlx_rs::ops::zeros;

let arr = zeros::<f32>(&[2, 3]).unwrap();
ops::ones
fn<T>(shape: &[i32]) -> Result<Array>
Create array filled with ones
use mlx_rs::ops::ones;

let arr = ones::<f32>(&[2, 3]).unwrap();
ops::arange
fn<U, T>(start: U, stop: U, step: U) -> Result<Array>
Create array with evenly spaced values
use mlx_rs::ops::arange;

let arr = arange(0, 10, 2).unwrap(); // [0, 2, 4, 6, 8]
ops::eye
fn<T>(n: i32, m: Option<i32>, k: i32) -> Result<Array>
Create 2D identity matrix
use mlx_rs::ops::eye;

let arr = eye::<f32>(3, None, 0).unwrap();
// [[1, 0, 0],
//  [0, 1, 0],
//  [0, 0, 1]]

Arithmetic operations

Arrays support standard arithmetic operators:
let a = array!([1.0, 2.0, 3.0]);
let b = array!([4.0, 5.0, 6.0]);

let sum = &a + &b;      // Element-wise addition
let diff = &a - &b;     // Element-wise subtraction
let prod = &a * &b;     // Element-wise multiplication
let quot = &a / &b;     // Element-wise division

Reduction operations

ops::sum
fn(array: &Array, axes: Option<&[i32]>) -> Result<Array>
Sum array elements over given axes
use mlx_rs::ops::sum;

let arr = array!([[1, 2], [3, 4]]);
let total = sum(&arr, None).unwrap(); // Sum all: 10
let row_sums = sum(&arr, Some(&[1])).unwrap(); // [3, 7]
ops::mean
fn(array: &Array, axes: Option<&[i32]>) -> Result<Array>
Compute mean over given axes
use mlx_rs::ops::mean;

let arr = array!([1.0, 2.0, 3.0, 4.0]);
let avg = mean(&arr, None).unwrap(); // 2.5
ops::argmax_axis
fn(array: &Array, axis: i32) -> Result<Array>
Indices of maximum values along axis
use mlx_rs::argmax_axis;

let arr = array!([[1, 3, 2], [4, 2, 5]]);
let indices = argmax_axis!(&arr, -1).unwrap(); // [1, 2]

Shape operations

ops::reshape
fn(array: &Array, shape: &[i32]) -> Result<Array>
Reshape array to new dimensions
use mlx_rs::ops::reshape;

let arr = array!([1, 2, 3, 4, 5, 6]);
let reshaped = reshape(&arr, &[2, 3]).unwrap();
ops::transpose
fn(array: &Array, axes: &[i32]) -> Result<Array>
Permute array dimensions
use mlx_rs::ops::transpose;

let arr = array!([[1, 2, 3], [4, 5, 6]]); // Shape: [2, 3]
let transposed = transpose(&arr, &[1, 0]).unwrap(); // Shape: [3, 2]
ops::concatenate_axis
fn(arrays: &[Array], axis: i32) -> Result<Array>
Join arrays along existing axis
use mlx_rs::ops::concatenate_axis;

let a = array!([1, 2]);
let b = array!([3, 4]);
let concat = concatenate_axis(&[a, b], 0).unwrap(); // [1, 2, 3, 4]

Indexing

The ops::indexing module provides array indexing operations:
use mlx_rs::ops::indexing::{IndexOp, NewAxis};

let arr = array!([[1, 2, 3], [4, 5, 6]]);

// Slice first row
let row = arr.index(0);

// Slice with ranges
let sub = arr.index((.., 1..3)); // All rows, columns 1-2

// Add dimension
let expanded = arr.index(NewAxis); // Shape: [1, 2, 3]

Transforms

Function transformations for automatic differentiation and compilation.

Gradient computation

transforms::grad
fn
Compute gradient of function with respect to first argument
use mlx_rs::{Array, error::Exception, transforms::grad};

fn f(args: &[Array]) -> Result<Array, Exception> {
    let x = &args[0];
    x.square() // f(x) = x²
}

let mut grad_fn = grad(f);
let x = Array::from_f32(3.0);
let df_dx = grad_fn(&[x]).unwrap();
// df/dx = 2x = 6.0
assert_eq!(df_dx.item::<f32>(), 6.0);
transforms::value_and_grad
fn
Compute both function value and gradient
use mlx_rs::{Array, error::Exception, transforms::value_and_grad};

fn loss(args: &[Array]) -> Result<Array, Exception> {
    let x = &args[0];
    x.square()
}

let mut vg_fn = value_and_grad(loss);
let x = Array::from_f32(3.0);
let (value, grad) = vg_fn(&[x]).unwrap();

Evaluation

transforms::eval
fn(outputs: impl IntoIterator<Item = &Array>) -> Result<()>
Evaluate multiple arrays in one graph evaluation
use mlx_rs::transforms::eval;

let a = array!([1, 2, 3]);
let b = array!([4, 5, 6]);
let c = &a + &b;
let d = &a * &b;

eval(&[&c, &d]).unwrap(); // Evaluate both together
transforms::eval_params
fn(params: ModuleParamRef) -> Result<()>
Evaluate all parameters of a module
use mlx_rs::transforms::eval_params;

// After updating model parameters
eval_params(model.parameters()).unwrap();

Random

Random number generation operations.
random::normal
fn
Generate array from normal distribution
use mlx_rs::normal;

let arr = normal!(shape = &[100]).unwrap();
// Mean 0, std 1
random::uniform
fn
Generate array from uniform distribution
use mlx_rs::uniform;

let arr = uniform!(low = 0.0, high = 1.0, shape = &[100]).unwrap();
random::categorical
fn
Sample from categorical distribution
use mlx_rs::categorical;

let logits = array!([0.1, 0.5, 0.4]);
let sample = categorical!(&logits).unwrap();

Neural networks (nn)

The nn module provides neural network layers and utilities.

Layers

  • nn::Linear - Fully connected layer
  • nn::Conv1d, nn::Conv2d - Convolution layers
  • nn::LayerNorm - Layer normalization
  • nn::Dropout - Dropout regularization
  • nn::Embedding - Embedding lookup table

Activations

  • nn::relu - ReLU activation
  • nn::gelu - GELU activation
  • nn::silu - SiLU/Swish activation
  • nn::softmax - Softmax function

Rotary position embedding

nn::Rope
struct
Rotary position embedding for transformers
use mlx_rs::nn::{RopeBuilder, Rope};

let rope = RopeBuilder::new(64)
    .traditional(false)
    .base(10000.0)
    .build()
    .unwrap();

Module system

The module module provides traits for neural network modules.
module::Module
trait
Trait for neural network modules with learnable parameters
use mlx_rs::module::Module;

// All nn layers implement Module
let output = model.forward(&input).unwrap();
let params = model.parameters();

Saving and loading

Arrays and models can be saved/loaded:
Type                     Load                                       Save
Array                    Array::load_numpy                          Array::save_numpy
HashMap<String, Array>   Array::load_safetensors                    Array::save_safetensors
Module                   ModuleParametersExt::load_safetensors      ModuleParametersExt::save_safetensors
// Save single array
arr.save_numpy("array.npy").unwrap();

// Load single array
let arr = Array::load_numpy("array.npy").unwrap();

// Save model parameters
model.save_safetensors("model.safetensors").unwrap();

// Load model parameters
model.load_safetensors("model.safetensors").unwrap();

Lazy evaluation

MLX uses lazy evaluation - operations build compute graphs without executing:
use mlx_rs::{array, transforms::eval};

let a = array!([1, 2, 3, 4]);
let b = array!([1.0, 2.0, 3.0, 4.0]);

// No computation happens yet
let c = &a + &b;
let d = &c * 2.0;

// Evaluate when needed
d.eval().unwrap();

// Or evaluation happens automatically when:
println!("{:?}", d);     // Printing
let val: f32 = d.item(); // Getting scalar value
let slice = d.as_slice::<f32>(); // Accessing data

When to evaluate

A natural place to use eval() is at each iteration of an outer loop:
for batch in dataset {
    // Build compute graph
    let (loss, grad) = value_and_grad_fn(&mut model, batch)?;
    
    // Update parameters (still lazy)
    optimizer.update(&mut model, grad)?;
    
    // Evaluate loss and parameters together
    eval_params(model.parameters())?;
}

Unified memory

On Apple Silicon, CPU and GPU share unified memory:
use mlx_rs::normal;

let a = normal!(shape = &[100]).unwrap();
let b = normal!(shape = &[100]).unwrap();

// Both live in unified memory
// Operations specify device at runtime:
let c = mlx_rs::add!(&a, &b, stream = StreamOrDevice::cpu()).unwrap();
let d = mlx_rs::add!(&a, &b, stream = StreamOrDevice::gpu()).unwrap();
No explicit memory transfers needed - arrays are accessible to all devices.

Example: Linear regression

use mlx_rs::{array, ops, transforms, Array};
use mlx_rs::error::Exception;

fn main() -> Result<(), Exception> {
    // Generate synthetic data
    let w_star = mlx_rs::normal!(shape = &[100])?;
    let x = mlx_rs::normal!(shape = &[1000, 100])?;
    let eps = mlx_rs::normal!(shape = &[1000])? * 1e-2;
    let y = x.matmul(&w_star)? + eps;
    
    // Initialize weights
    let w = mlx_rs::normal!(shape = &[100])? * 1e-2;
    
    // Define loss function
    let loss_fn = |inputs: &[Array]| -> Result<Array, Exception> {
        let w = &inputs[0];
        let x = &inputs[1];
        let y = &inputs[2];
        
        let y_pred = x.matmul(w)?;
        let loss = Array::from_f32(0.5) * ops::mean(&ops::square(y_pred - y)?, None)?;
        Ok(loss)
    };
    
    // Train
    let mut grad_fn = transforms::grad(loss_fn);
    let mut inputs = [w, x, y];
    
    for _ in 0..10000 {
        let grad = grad_fn(&inputs)?;
        inputs[0] = &inputs[0] - Array::from_f32(0.01) * grad;
        inputs[0].eval()?;
    }
    
    let loss = loss_fn(&inputs)?;
    println!("Final loss: {:.5}", loss.item::<f32>());
    
    Ok(())
}

Module structure

  • array - Core Array type and operations
  • ops - Array operations (arithmetic, reduction, shape)
  • transforms - Function transforms (grad, eval, compile)
  • random - Random number generation
  • nn - Neural network layers and activations
  • module - Module system for composable networks
  • optimizers - Optimization algorithms (SGD, Adam, etc.)
  • losses - Loss functions
  • linalg - Linear algebra operations
  • fft - Fast Fourier transform
  • fast - Hardware-optimized operations
  • dtype - Data type definitions
  • device - Device management
  • stream - Computation stream management

Build docs developers (and LLMs) love