Overview
MLX provides function transformations including automatic differentiation (grad, vjp, jvp), vectorization (vmap), compilation (compile), and custom transformation rules (custom_function). These transforms enable efficient computation of derivatives, batched processing, and optimized execution.
Evaluation
eval
mx.eval(*arrays: array) -> None
mx.eval(arrays: list[array]) -> None
Evaluate arrays synchronously.
This function forces immediate computation of the provided arrays, blocking until all computations complete.
arrays (array | list[array], required): Arrays to evaluate. Can be passed as individual arguments or as a list.
Example:
import mlx.core as mx
x = mx.array([1, 2, 3])
y = x + 1
mx.eval(y) # Force computation
# Multiple arrays
a = mx.array([1, 2])
b = mx.array([3, 4])
mx.eval(a, b)
async_eval
mx.async_eval(*arrays: array) -> None
mx.async_eval(arrays: list[array]) -> None
Evaluate arrays asynchronously.
This function schedules computation of the provided arrays but returns immediately without blocking.
arrays (array | list[array], required): Arrays to evaluate. Can be passed as individual arguments or as a list.
Example:
x = mx.array([1, 2, 3])
y = x * 2
mx.async_eval(y) # Schedule computation, don't wait
Automatic Differentiation
grad
mx.grad(
fun: callable,
argnums: int | tuple[int] = 0
) -> callable
Returns a function that computes the gradient of fun.
The function being differentiated must return a scalar array. The gradient is computed with respect to the arguments specified by argnums.
fun (callable, required): Function to differentiate. Must return a scalar array.
argnums (int | tuple[int], default 0): Index or indices of the arguments to differentiate with respect to.
Returns: A function that computes the gradient of fun with respect to the specified arguments.
Example:
import mlx.core as mx
# Gradient of x^2
def f(x):
    return (x * x).sum()
grad_f = mx.grad(f)
x = mx.array([1.0, 2.0, 3.0])
dfdx = grad_f(x) # [2.0, 4.0, 6.0]
# Multiple arguments
def g(x, y):
    return (x * y).sum()
grad_g = mx.grad(g, argnums=(0, 1))
dgdx, dgdy = grad_g(mx.array([1.0, 2.0]), mx.array([3.0, 4.0]))
value_and_grad
mx.value_and_grad(
fun: callable,
argnums: int | tuple[int] = 0
) -> callable
Returns a function that computes both the value and gradient of fun.
fun (callable, required): Function to differentiate. Must return a scalar array, or a tuple whose first element is a scalar array.
argnums (int | tuple[int], default 0): Index or indices of the arguments to differentiate with respect to.
Returns: A function that returns (value, gradient), or ((value, aux), gradient) if fun returns auxiliary data.
Example:
def f(x):
    return (x * x).sum()
value_grad_f = mx.value_and_grad(f)
value, grad = value_grad_f(mx.array([1.0, 2.0, 3.0]))
# value: 14.0, grad: [2.0, 4.0, 6.0]
# With auxiliary outputs
def f_with_aux(x):
    loss = (x * x).sum()
    aux = {"intermediate": x * 2}
    return loss, aux
value_grad_f = mx.value_and_grad(f_with_aux)
(loss, aux), grad = value_grad_f(mx.array([1.0, 2.0]))
vjp
mx.vjp(
fun: callable,
primals: list[array],
cotangents: list[array]
) -> tuple[list[array], list[array]]
Compute the vector-Jacobian product (VJP).
Computes both the function output and the product of the cotangents with the Jacobian of fun evaluated at primals.
fun (callable, required): Function to differentiate. Takes a list of arrays and returns a list of arrays.
primals (list[array], required): Input arrays at which to evaluate the function.
cotangents (list[array], required): Vectors to multiply with the Jacobian; their shapes must match the function's outputs.
Returns: Tuple (outputs, vjps) containing the function outputs and the VJP results.
Example:
def f(x):
    return 2 * x
out, vjp_out = mx.vjp(f, [mx.array(1.0)], [mx.array(2.0)])
# out: [2.0], vjp_out: [4.0]
jvp
mx.jvp(
fun: callable,
primals: list[array],
tangents: list[array]
) -> tuple[list[array], list[array]]
Compute the Jacobian-vector product (JVP).
Computes both the function output and the product of the Jacobian of fun evaluated at primals with the tangent vectors.
fun (callable, required): Function to differentiate. Takes a list of arrays and returns a list of arrays.
primals (list[array], required): Input arrays at which to evaluate the function.
tangents (list[array], required): Tangent vectors to multiply with the Jacobian; their shapes must match the function's inputs.
Returns: Tuple (outputs, jvps) containing the function outputs and the JVP results.
Example:
def f(x, y):
    return x * y
out, jvp_out = mx.jvp(
    f,
    [mx.array(4.0), mx.array(2.0)],
    [mx.array(3.0), mx.array(2.0)]
)
# out: [8.0], jvp_out: [4.0*2.0 + 2.0*3.0] = [14.0]
Vectorization
vmap
mx.vmap(
fun: callable,
in_axes: int | tuple = 0,
out_axes: int | tuple = 0
) -> callable
Automatically vectorize a function over specified axes.
Transforms a function that operates on individual examples into one that operates on batches.
in_axes (int | tuple, default 0): Axis of each input argument to map over. Use None to broadcast an argument instead of mapping it.
out_axes (int | tuple, default 0): Axis in each output where the mapped axis is placed.
Returns: The vectorized function.
Example:
import mlx.core as mx
# Vectorize over first dimension
def f(x):
    return x * 2
vf = mx.vmap(f)
x = mx.array([[1, 2], [3, 4], [5, 6]])
result = vf(x) # Applies f to each row
# Custom axes
def g(x, y):
    return x + y
vg = mx.vmap(g, in_axes=(0, None)) # Batch x, broadcast y
x = mx.array([1, 2, 3])
y = mx.array(10)
result = vg(x, y) # [11, 12, 13]
Compilation
compile
mx.compile(
fun: callable,
inputs: tuple = None,
outputs: tuple = None
) -> callable
Compile a function for optimized execution.
Compiled functions are optimized and can run significantly faster, especially for complex computational graphs.
inputs (list | dict, optional): Arrays, or trees of arrays, that are implicit inputs to fun (state read inside the function) and should be captured during compilation.
outputs (list | dict, optional): Arrays, or trees of arrays, that are implicit outputs of fun (state updated inside the function) and should be captured during compilation.
Returns: The compiled function.
Example:
import mlx.core as mx
def f(x, y):
    return mx.exp(x) + mx.log(y)
compiled_f = mx.compile(f)
result = compiled_f(mx.array([1.0, 2.0]), mx.array([3.0, 4.0]))
disable_compile
mx.disable_compile() -> None
Globally disable compilation. Compiled functions fall back to running uncompiled until compilation is re-enabled. Setting the MLX_DISABLE_COMPILE environment variable also disables compilation.
Example:
mx.disable_compile()
result = compiled_function(x)  # Runs without compilation
enable_compile
mx.enable_compile() -> None
Globally enable compilation. This overrides the MLX_DISABLE_COMPILE environment variable if it is set.
Example:
mx.enable_compile()
result = function(x)  # Will be compiled
custom_function
mx.custom_function(fun: callable) -> custom_function
Set up custom transformation rules for a function.
Wraps fun so that custom implementations of its VJP, JVP, and vmap transformations can be attached with the vjp, jvp, and vmap decorator methods of the returned object. Transformations without a custom rule fall back to the default behavior.
fun (callable, required): The function to transform.
Returns: A wrapped function to which custom transformation rules can be attached.
Example:
@mx.custom_function
def my_fun(x, y):
    return x + 2 * y

@my_fun.vjp
def my_fun_vjp(primals, cotangents, outputs):
    x, y = primals
    # Custom backward pass: d/dx = 1, d/dy = 2
    return cotangents, 2 * cotangents
checkpoint
mx.checkpoint(fun: callable) -> callable
Checkpoint the gradient computation.
Discards intermediate activations during the forward pass and recomputes them during the backward pass. This trades computation for memory, useful for training large models.
Returns: Checkpointed function
Example:
def expensive_computation(x):
    # Many intermediate computations
    for _ in range(100):
        x = mx.exp(x)
        x = mx.log(x + 1)
    return x
# Save memory by recomputing in backward pass
checkpointed_f = mx.checkpoint(expensive_computation)
grad_f = mx.grad(checkpointed_f)