Overview

MLX provides composable function transformations: automatic differentiation (grad, vjp, jvp), vectorization (vmap), compilation (compile), and custom gradient definitions. These transforms enable efficient computation of derivatives, batch processing, and optimized execution.

Evaluation

eval

mx.eval(*arrays: array) -> None
mx.eval(arrays: list[array]) -> None
Evaluate arrays synchronously. This function forces immediate computation of the provided arrays, blocking until all computations complete.
arrays (array | list[array], required): Arrays to evaluate. Can be passed as individual arguments or as a list.
Example:
import mlx.core as mx

x = mx.array([1, 2, 3])
y = x + 1
mx.eval(y)  # Force computation

# Multiple arrays
a = mx.array([1, 2])
b = mx.array([3, 4])
mx.eval(a, b)

async_eval

mx.async_eval(*arrays: array) -> None
mx.async_eval(arrays: list[array]) -> None
Evaluate arrays asynchronously. This function schedules computation of the provided arrays and returns immediately without blocking; the results become available once the scheduled work completes (for example, after a subsequent mx.eval).
arrays (array | list[array], required): Arrays to evaluate. Can be passed as individual arguments or as a list.
Example:
x = mx.array([1, 2, 3])
y = x * 2
mx.async_eval(y)  # Schedule computation, don't wait

Automatic Differentiation

grad

mx.grad(
    fun: callable,
    argnums: int | tuple[int] = 0
) -> callable
Returns a function that computes the gradient of fun. The function being differentiated must return a scalar array. The gradient is computed with respect to the arguments specified by argnums.
fun (callable, required): Function to differentiate. Must return a scalar array.
argnums (int | tuple[int], default: 0): Index or indices of the arguments to differentiate with respect to.
Returns: Function that computes gradients.
Example:
import mlx.core as mx

# Gradient of x^2
def f(x):
    return (x * x).sum()

grad_f = mx.grad(f)
x = mx.array([1.0, 2.0, 3.0])
dfdx = grad_f(x)  # [2.0, 4.0, 6.0]

# Multiple arguments
def g(x, y):
    return (x * y).sum()

grad_g = mx.grad(g, argnums=(0, 1))
dgdx, dgdy = grad_g(mx.array([1.0, 2.0]), mx.array([3.0, 4.0]))

value_and_grad

mx.value_and_grad(
    fun: callable,
    argnums: int | tuple[int] = 0
) -> callable
Returns a function that computes both the value and gradient of fun.
fun (callable, required): Function to differentiate. Must return a scalar array (or a tuple whose first element is a scalar array).
argnums (int | tuple[int], default: 0): Index or indices of the arguments to differentiate with respect to.
Returns: Function that returns (value, gradient), or ((value, aux), gradient) if fun returns auxiliary data.
Example:
def f(x):
    return (x * x).sum()

value_grad_f = mx.value_and_grad(f)
value, grad = value_grad_f(mx.array([1.0, 2.0, 3.0]))
# value: 14.0, grad: [2.0, 4.0, 6.0]

# With auxiliary outputs
def f_with_aux(x):
    loss = (x * x).sum()
    aux = {"intermediate": x * 2}
    return loss, aux

value_grad_f = mx.value_and_grad(f_with_aux)
(loss, aux), grad = value_grad_f(mx.array([1.0, 2.0]))

vjp

mx.vjp(
    fun: callable,
    primals: list[array],
    cotangents: list[array]
) -> tuple[list[array], list[array]]
Compute the vector-Jacobian product (VJP). Returns both the output of fun evaluated at primals and the product of the cotangents with the Jacobian of fun at primals.
fun (callable, required): Function to differentiate. Takes a list of arrays and returns a list of arrays.
primals (list[array], required): Input arrays at which to evaluate the function.
cotangents (list[array], required): Vectors to multiply with the Jacobian.
Returns: Tuple (outputs, vjps) containing the function outputs and the VJP results.
Example:
def f(x):
    return 2 * x

out, vjp_out = mx.vjp(f, [mx.array(1.0)], [mx.array(2.0)])
# out: [2.0], vjp_out: [4.0]

jvp

mx.jvp(
    fun: callable,
    primals: list[array],
    tangents: list[array]
) -> tuple[list[array], list[array]]
Compute the Jacobian-vector product (JVP). Returns both the output of fun evaluated at primals and the product of the Jacobian of fun at primals with the tangent vectors.
fun (callable, required): Function to differentiate. Takes a list of arrays and returns a list of arrays.
primals (list[array], required): Input arrays at which to evaluate the function.
tangents (list[array], required): Tangent vectors to multiply with the Jacobian.
Returns: Tuple (outputs, jvps) containing the function outputs and the JVP results.
Example:
def f(x, y):
    return x * y

out, jvp_out = mx.jvp(
    f,
    [mx.array(4.0), mx.array(2.0)],
    [mx.array(3.0), mx.array(2.0)]
)
# out: [8.0], jvp_out: [4.0*2.0 + 2.0*3.0] = [14.0]

Vectorization

vmap

mx.vmap(
    fun: callable,
    in_axes: int | tuple = 0,
    out_axes: int | tuple = 0
) -> callable
Automatically vectorize a function over specified axes. Transforms a function that operates on individual examples to one that operates on batches.
fun (callable, required): Function to vectorize.
in_axes (int | tuple, default: 0): Axis or axes to map over for the input arguments. Use None to broadcast an argument.
out_axes (int | tuple, default: 0): Axis or axes for the output.
Returns: Vectorized function.
Example:
import mlx.core as mx

# Vectorize over first dimension
def f(x):
    return x * 2

vf = mx.vmap(f)
x = mx.array([[1, 2], [3, 4], [5, 6]])
result = vf(x)  # Applies f to each row

# Custom axes
def g(x, y):
    return x + y

vg = mx.vmap(g, in_axes=(0, None))  # Batch x, broadcast y
x = mx.array([1, 2, 3])
y = mx.array(10)
result = vg(x, y)  # [11, 12, 13]

Compilation

compile

mx.compile(
    fun: callable,
    inputs: tuple = None,
    outputs: tuple = None
) -> callable
Compile a function for optimized execution. Compiled functions are optimized and can run significantly faster, especially for complex computational graphs.
fun (callable, required): Function to compile.
inputs (tuple, default: None): Input specifications for compilation.
outputs (tuple, default: None): Output specifications for compilation.
Returns: Compiled function.
Example:
import mlx.core as mx

def f(x, y):
    return mx.exp(x) + mx.log(y)

compiled_f = mx.compile(f)
result = compiled_f(mx.array([1.0, 2.0]), mx.array([3.0, 4.0]))

disable_compile

mx.disable_compile() -> None
Globally disable compilation. Compiled functions fall back to executing the underlying Python function. Compilation can also be disabled by setting the MLX_DISABLE_COMPILE environment variable.
Example:
mx.disable_compile()
result = compiled_function(x)  # Runs eagerly, without compilation

enable_compile

mx.enable_compile() -> None
Globally enable compilation. This overrides a previous mx.disable_compile() call and the MLX_DISABLE_COMPILE environment variable.
Example:
mx.enable_compile()
result = compiled_function(x)  # Runs compiled again

Custom Transformations

custom_function

mx.custom_function(
    fun: callable,
    fun_vjp: callable = None,
    fun_jvp: callable = None,
    fun_vmap: callable = None
) -> callable
Define custom transformations for a function. Allows you to specify custom implementations for VJP, JVP, and vmap transformations.
fun (callable, required): The function to transform.
fun_vjp (callable, default: None): Custom VJP implementation.
fun_jvp (callable, default: None): Custom JVP implementation.
fun_vmap (callable, default: None): Custom vmap implementation.
Returns: Function with custom transformations.
Example:
def my_fun(x):
    return x * 2

def my_vjp(primals, cotangents, outputs):
    # Custom backward pass
    return [c * 2 for c in cotangents]

custom_f = mx.custom_function(my_fun, fun_vjp=my_vjp)

checkpoint

mx.checkpoint(fun: callable) -> callable
Checkpoint the gradient computation. Discards intermediate activations during the forward pass and recomputes them during the backward pass. This trades computation for memory, useful for training large models.
fun (callable, required): Function to checkpoint.
Returns: Checkpointed function.
Example:
def expensive_computation(x):
    # Many intermediate computations
    for _ in range(100):
        x = mx.exp(x)
        x = mx.log(x + 1)
    return x

# Save memory by recomputing in backward pass
checkpointed_f = mx.checkpoint(expensive_computation)
grad_f = mx.grad(checkpointed_f)
