MLX provides operations for array creation, shape manipulation, arithmetic, mathematical functions, and linear algebra. Most operations follow NumPy-style broadcasting and accept an optional stream argument that selects where the operation is scheduled.
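To make the broadcasting rule concrete, here is a minimal pure-Python sketch of how a result shape is derived under NumPy-style broadcasting. The helper `broadcast_shape` is hypothetical and is not part of MLX; it only models the shape rule and does not handle zero-sized dimensions.

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape (illustrative helper, not an MLX function)."""
    result = []
    # Align the shapes from the trailing dimension; missing dimensions count as 1.
    for dim_a, dim_b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if dim_a == dim_b or dim_b == 1:
            result.append(dim_a)
        elif dim_a == 1:
            result.append(dim_b)
        else:
            raise ValueError(f"shapes {shape_a} and {shape_b} cannot be broadcast")
    return tuple(reversed(result))


print(broadcast_shape((3,), ()))        # (3,)   array combined with a scalar
print(broadcast_shape((2, 3), (3,)))    # (2, 3)
print(broadcast_shape((4, 1), (1, 5)))  # (4, 5)
```

Two dimensions are compatible when they are equal or when one of them is 1; anything else raises an error in this sketch.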
Array Creation
arange
Create a 1D array with evenly spaced values.
import mlx.core as mx
# Range from 0 to 10 (exclusive)
a = mx.arange(10)
print(a) # array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
# Range with start and stop
b = mx.arange(5, 10)
print(b) # array([5, 6, 7, 8, 9], dtype=int32)
# Range with step
c = mx.arange(0, 10, 2)
print(c) # array([0, 2, 4, 6, 8], dtype=int32)
# Floating point range
d = mx.arange(0.0, 1.0, 0.25)
print(d) # array([0, 0.25, 0.5, 0.75], dtype=float32)
End of the range (exclusive).
The data type of the output array. If not specified, inferred from the inputs.
The stream on which to schedule the operation.
1D array of evenly spaced values.
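The number of values in the result follows from the start, stop, and step. The sketch below models that rule in pure Python, assuming NumPy-style semantics (length is the ceiling of (stop - start) / step, clamped at zero). `arange_length` is a hypothetical helper, not an MLX function.

```python
import math


def arange_length(start, stop, step=1):
    """Number of values in the half-open range [start, stop) for a given step."""
    if step == 0:
        raise ValueError("step must be nonzero")
    return max(math.ceil((stop - start) / step), 0)


print(arange_length(0, 10))           # 10
print(arange_length(5, 10))           # 5
print(arange_length(0, 10, 2))        # 5
print(arange_length(0.0, 1.0, 0.25))  # 4
print(arange_length(10, 0))           # 0 (empty range)
```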
linspace
Create a 1D array with a specified number of evenly spaced values.
import mlx.core as mx
# 5 values from 0 to 1 (inclusive)
a = mx.linspace(0, 1, 5)
print(a) # array([0, 0.25, 0.5, 0.75, 1], dtype=float32)
# 10 values from -1 to 1
b = mx.linspace(-1, 1, 10)
print(b)
End of the range (inclusive).
Number of values to generate.
The data type of the output array.
The stream on which to schedule the operation.
1D array of evenly spaced values.
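Unlike arange, linspace fixes the number of values and includes the end point, so the spacing is (stop - start) / (num - 1). A minimal pure-Python sketch of that computation follows; `linspace_values` is a hypothetical helper, and it only handles num >= 2 to avoid making claims about edge cases.

```python
def linspace_values(start, stop, num):
    """Evenly spaced values over the closed interval [start, stop]."""
    if num < 2:
        raise ValueError("this sketch only handles num >= 2")
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


print(linspace_values(0, 1, 5))        # [0.0, 0.25, 0.5, 0.75, 1.0]
print(len(linspace_values(-1, 1, 10)))  # 10 values, spaced 2/9 apart
```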
zeros
Create an array filled with zeros.
import mlx.core as mx
a = mx.zeros(5)
print(a) # array([0, 0, 0, 0, 0], dtype=float32)
b = mx.zeros((2, 3))
print(b)
# array([[0, 0, 0],
# [0, 0, 0]], dtype=float32)
c = mx.zeros((2, 3), dtype=mx.int32)
print(c.dtype) # mlx.core.int32
shape
int | tuple[int, ...]
required
Shape of the array.
The data type of the output array.
The stream on which to schedule the operation.
ones
Create an array filled with ones.
import mlx.core as mx
a = mx.ones(5)
print(a) # array([1, 1, 1, 1, 1], dtype=float32)
b = mx.ones((2, 3), dtype=mx.int32)
print(b)
# array([[1, 1, 1],
# [1, 1, 1]], dtype=int32)
shape
int | tuple[int, ...]
required
Shape of the array.
The data type of the output array.
The stream on which to schedule the operation.
full
Create an array filled with a specified value.
import mlx.core as mx
a = mx.full((2, 3), 7)
print(a)
# array([[7, 7, 7],
# [7, 7, 7]], dtype=int32)
b = mx.full((3,), 3.14, dtype=mx.float32)
print(b) # array([3.14, 3.14, 3.14], dtype=float32)
shape
int | tuple[int, ...]
required
Shape of the array.
The data type of the output array. If not specified, inferred from vals.
The stream on which to schedule the operation.
Array filled with the specified value.
eye
Create a 2D array with ones on a chosen diagonal and zeros elsewhere. With default arguments this is a square identity matrix.
import mlx.core as mx
# 3x3 identity matrix
a = mx.eye(3)
print(a)
# array([[1, 0, 0],
# [0, 1, 0],
# [0, 0, 1]], dtype=float32)
# 3x4 matrix with ones on diagonal
b = mx.eye(3, 4)
print(b)
# array([[1, 0, 0, 0],
# [0, 1, 0, 0],
# [0, 0, 1, 0]], dtype=float32)
# With offset diagonal
c = mx.eye(3, 3, k=1)
print(c)
# array([[0, 1, 0],
# [0, 0, 1],
# [0, 0, 0]], dtype=float32)
Number of columns. If not specified, defaults to n.
Index of the diagonal. 0 is the main diagonal, positive values are above, negative below.
The data type of the output array.
The stream on which to schedule the operation.
2D array with ones on the specified diagonal.
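The effect of the diagonal index k can be stated as a single condition: an entry is one when its column index minus its row index equals k. The pure-Python sketch below models that condition; `eye_rows` is a hypothetical helper that returns nested lists, not an MLX array.

```python
def eye_rows(n, m=None, k=0):
    """Rows of an n x m matrix with ones where column - row == k."""
    m = n if m is None else m
    return [[1.0 if col - row == k else 0.0 for col in range(m)] for row in range(n)]


# A negative k places the ones below the main diagonal.
for row in eye_rows(3, 4, k=-1):
    print(row)
# [0.0, 0.0, 0.0, 0.0]
# [1.0, 0.0, 0.0, 0.0]
# [0.0, 1.0, 0.0, 0.0]
```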
identity
Create a square identity matrix.
import mlx.core as mx
a = mx.identity(4)
print(a)
# array([[1, 0, 0, 0],
# [0, 1, 0, 0],
# [0, 0, 1, 0],
# [0, 0, 0, 1]], dtype=float32)
Size of the square matrix.
The data type of the output array.
The stream on which to schedule the operation.
Square identity matrix of size n x n.
Shape Manipulation
reshape
Reshape an array without changing its data.
import mlx.core as mx
a = mx.arange(12)
print(a.shape) # (12,)
b = mx.reshape(a, (3, 4))
print(b.shape) # (3, 4)
print(b)
# array([[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10, 11]], dtype=int32)
# Use -1 to infer dimension
c = mx.reshape(a, (2, -1))
print(c.shape) # (2, 6)
New shape. One dimension can be -1, which will be inferred.
The stream on which to schedule the operation.
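The -1 placeholder is resolved by dividing the total number of elements by the product of the known dimensions. The following pure-Python sketch illustrates that inference; `infer_shape` is a hypothetical helper, not an MLX function.

```python
import math


def infer_shape(size, shape):
    """Resolve a single -1 entry in `shape` for an array holding `size` elements."""
    shape = list(shape)
    unknown = [i for i, dim in enumerate(shape) if dim == -1]
    if len(unknown) > 1:
        raise ValueError("only one dimension can be -1")
    known = math.prod(dim for dim in shape if dim != -1)
    if unknown:
        if known == 0 or size % known != 0:
            raise ValueError(f"cannot reshape {size} elements into {tuple(shape)}")
        shape[unknown[0]] = size // known
    elif known != size:
        raise ValueError(f"cannot reshape {size} elements into {tuple(shape)}")
    return tuple(shape)


print(infer_shape(12, (2, -1)))     # (2, 6)
print(infer_shape(12, (-1, 3, 2)))  # (2, 3, 2)
```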
flatten
Flatten an array to 1D, or collapse a contiguous range of axes into a single axis.
import mlx.core as mx
a = mx.array([[1, 2, 3], [4, 5, 6]])
b = mx.flatten(a)
print(b) # array([1, 2, 3, 4, 5, 6], dtype=int32)
# Flatten specific axes
c = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print(c.shape) # (2, 2, 2)
d = mx.flatten(c, start_axis=1)
print(d.shape) # (2, 4)
The stream on which to schedule the operation.
squeeze
Remove singleton dimensions.
import mlx.core as mx
a = mx.array([[[1, 2, 3]]])
print(a.shape) # (1, 1, 3)
b = mx.squeeze(a)
print(b.shape) # (3,)
# Squeeze specific axis
c = mx.squeeze(a, axis=0)
print(c.shape) # (1, 3)
Axis or axes to squeeze. If not specified, all singleton dimensions are removed.
The stream on which to schedule the operation.
expand_dims
Add singleton dimensions to an array.
import mlx.core as mx
a = mx.array([1, 2, 3])
print(a.shape) # (3,)
b = mx.expand_dims(a, axis=0)
print(b.shape) # (1, 3)
c = mx.expand_dims(a, axis=1)
print(c.shape) # (3, 1)
# Add multiple dimensions
d = mx.expand_dims(a, axis=(0, 2))
print(d.shape) # (1, 3, 1)
axis
int | tuple[int, ...]
required
Position(s) where new axis should be added.
The stream on which to schedule the operation.
Array with added dimensions.
transpose
Permute the dimensions of an array.
import mlx.core as mx
a = mx.array([[1, 2, 3], [4, 5, 6]])
print(a.shape) # (2, 3)
# Reverse all dimensions
b = mx.transpose(a)
print(b.shape) # (3, 2)
print(b)
# array([[1, 4],
# [2, 5],
# [3, 6]], dtype=int32)
# Specific permutation for 3D
c = mx.zeros((2, 3, 4))
d = mx.transpose(c, (2, 0, 1))
print(d.shape) # (4, 2, 3)
Permutation of axes. If not specified, reverses all dimensions.
The stream on which to schedule the operation.
swapaxes
Swap two axes of an array.
import mlx.core as mx
a = mx.zeros((2, 3, 4))
print(a.shape) # (2, 3, 4)
b = mx.swapaxes(a, 0, 2)
print(b.shape) # (4, 3, 2)
The stream on which to schedule the operation.
moveaxis
Move an axis to a new position.
import mlx.core as mx
a = mx.zeros((2, 3, 4, 5))
print(a.shape) # (2, 3, 4, 5)
b = mx.moveaxis(a, 0, -1)
print(b.shape) # (3, 4, 5, 2)
c = mx.moveaxis(a, 2, 0)
print(c.shape) # (4, 2, 3, 5)
Original position of the axis.
Destination position for the axis.
The stream on which to schedule the operation.
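swapaxes and moveaxis are easy to confuse: swapping exchanges two axes, while moving removes one axis and reinserts it, shifting the axes in between. The pure-Python sketch below computes the resulting shapes under NumPy-style semantics; both helpers are hypothetical and operate on shape tuples only.

```python
def swapaxes_shape(shape, axis1, axis2):
    """Shape after exchanging two axes."""
    dims = list(shape)
    dims[axis1], dims[axis2] = dims[axis2], dims[axis1]
    return tuple(dims)


def moveaxis_shape(shape, source, destination):
    """Shape after moving one axis; the remaining axes keep their relative order."""
    ndim = len(shape)
    source %= ndim
    destination %= ndim
    order = [axis for axis in range(ndim) if axis != source]
    order.insert(destination, source)
    return tuple(shape[axis] for axis in order)


shape = (2, 3, 4, 5)
print(moveaxis_shape(shape, 0, -1))  # (3, 4, 5, 2)
print(moveaxis_shape(shape, 2, 0))   # (4, 2, 3, 5)
print(moveaxis_shape(shape, 0, 2))   # (3, 4, 2, 5)
print(swapaxes_shape(shape, 0, 2))   # (4, 3, 2, 5)
```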
Arithmetic Operations
add
Element-wise addition.
import mlx.core as mx
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
c = mx.add(a, b)
print(c) # array([5, 7, 9], dtype=int32)
# Operator overload
d = a + b
print(d) # array([5, 7, 9], dtype=int32)
# Broadcasting
e = mx.add(a, 10)
print(e) # array([11, 12, 13], dtype=int32)
The stream on which to schedule the operation.
subtract
Element-wise subtraction.
import mlx.core as mx
a = mx.array([5, 7, 9])
b = mx.array([1, 2, 3])
c = mx.subtract(a, b)
print(c) # array([4, 5, 6], dtype=int32)
# Operator overload
d = a - b
print(d) # array([4, 5, 6], dtype=int32)
The stream on which to schedule the operation.
multiply
Element-wise multiplication.
import mlx.core as mx
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
c = mx.multiply(a, b)
print(c) # array([4, 10, 18], dtype=int32)
# Operator overload
d = a * b
print(d) # array([4, 10, 18], dtype=int32)
The stream on which to schedule the operation.
divide
Element-wise division.
import mlx.core as mx
a = mx.array([10, 20, 30], dtype=mx.float32)
b = mx.array([2, 4, 5], dtype=mx.float32)
c = mx.divide(a, b)
print(c) # array([5, 5, 6], dtype=float32)
# Operator overload
d = a / b
print(d) # array([5, 5, 6], dtype=float32)
First input array (numerator).
Second input array (denominator).
The stream on which to schedule the operation.
floor_divide
Element-wise floor division (integer division).
import mlx.core as mx
a = mx.array([10, 21, 32])
b = mx.array([3, 4, 5])
c = mx.floor_divide(a, b)
print(c) # array([3, 5, 6], dtype=int32)
The stream on which to schedule the operation.
Element-wise floor quotient.
remainder
Element-wise remainder of division.
import mlx.core as mx
a = mx.array([10, 21, 32])
b = mx.array([3, 4, 5])
c = mx.remainder(a, b)
print(c) # array([1, 1, 2], dtype=int32)
# Operator overload
d = a % b
print(d) # array([1, 1, 2], dtype=int32)
The stream on which to schedule the operation.
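Floor division and remainder are linked by the identity a == b * floor(a / b) + remainder. The pure-Python sketch below checks that identity on the example values used here. It is restricted to positive operands and makes no claim about how negative operands are handled; `floor_div_and_remainder` is a hypothetical helper.

```python
def floor_div_and_remainder(a, b):
    """Return (quotient, remainder) for positive integers and verify the identity."""
    quotient, remainder = a // b, a % b
    assert a == b * quotient + remainder
    return quotient, remainder


for a, b in [(10, 3), (21, 4), (32, 5)]:
    print(a, b, floor_div_and_remainder(a, b))
# 10 3 (3, 1)
# 21 4 (5, 1)
# 32 5 (6, 2)
```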
power
Element-wise exponentiation.
import mlx.core as mx
a = mx.array([2, 3, 4])
b = mx.array([2, 2, 2])
c = mx.power(a, b)
print(c) # array([4, 9, 16], dtype=int32)
# Broadcasting
d = mx.power(a, 2)
print(d) # array([4, 9, 16], dtype=int32)
The stream on which to schedule the operation.
Mathematical Functions
abs
Element-wise absolute value.
import mlx.core as mx
a = mx.array([-1, -2, 3, -4])
b = mx.abs(a)
print(b) # array([1, 2, 3, 4], dtype=int32)
The stream on which to schedule the operation.
sqrt
Element-wise square root.
import mlx.core as mx
a = mx.array([1.0, 4.0, 9.0, 16.0])
b = mx.sqrt(a)
print(b) # array([1, 2, 3, 4], dtype=float32)
The stream on which to schedule the operation.
rsqrt
Element-wise reciprocal square root (1/sqrt(x)).
import mlx.core as mx
a = mx.array([1.0, 4.0, 9.0])
b = mx.rsqrt(a)
print(b) # array([1, 0.5, 0.333...], dtype=float32)
The stream on which to schedule the operation.
exp
Element-wise exponential (e^x).
import mlx.core as mx
a = mx.array([0.0, 1.0, 2.0])
b = mx.exp(a)
print(b) # array([1, 2.71828, 7.38906], dtype=float32)
The stream on which to schedule the operation.
log
Element-wise natural logarithm.
import mlx.core as mx
a = mx.array([1.0, 2.71828, 7.38906])
b = mx.log(a)
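# The inputs are rounded values of e and e^2, so the results are only approximately 1 and 2
print(b) # approximately array([0, 1, 2], dtype=float32)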
The stream on which to schedule the operation.
log2
Element-wise base-2 logarithm.
import mlx.core as mx
a = mx.array([1.0, 2.0, 4.0, 8.0])
b = mx.log2(a)
print(b) # array([0, 1, 2, 3], dtype=float32)
The stream on which to schedule the operation.
log10
Element-wise base-10 logarithm.
import mlx.core as mx
a = mx.array([1.0, 10.0, 100.0, 1000.0])
b = mx.log10(a)
print(b) # array([0, 1, 2, 3], dtype=float32)
The stream on which to schedule the operation.
Trigonometric Functions
sin
Element-wise sine.
import mlx.core as mx
import math
a = mx.array([0.0, math.pi/2, math.pi])
b = mx.sin(a)
# sin(pi) is not exactly 0 in float32; expect a tiny nonzero value in its place
print(b) # approximately array([0, 1, 0], dtype=float32)
The stream on which to schedule the operation.
cos
Element-wise cosine.
import mlx.core as mx
import math
a = mx.array([0.0, math.pi/2, math.pi])
b = mx.cos(a)
# cos(pi/2) is not exactly 0 in float32; expect a tiny nonzero value in its place
print(b) # approximately array([1, 0, -1], dtype=float32)
The stream on which to schedule the operation.
tan
Element-wise tangent.
import mlx.core as mx
import math
a = mx.array([0.0, math.pi/4])
b = mx.tan(a)
print(b) # array([0, 1], dtype=float32)
The stream on which to schedule the operation.
Tangents of input values.
Linear Algebra
matmul
Matrix multiplication.
import mlx.core as mx
# Matrix-matrix multiplication (floating-point inputs; matmul may reject integer dtypes)
a = mx.array([[1.0, 2.0], [3.0, 4.0]])
b = mx.array([[5.0, 6.0], [7.0, 8.0]])
c = mx.matmul(a, b)
print(c)
# array([[19, 22],
# [43, 50]], dtype=float32)
# Matrix-vector multiplication
d = mx.array([1.0, 2.0])
e = mx.matmul(a, d)
print(e) # array([5, 11], dtype=float32)
The stream on which to schedule the operation.
Matrix product of a and b.
outer
Outer product of two vectors.
import mlx.core as mx
a = mx.array([1, 2, 3])
b = mx.array([4, 5])
c = mx.outer(a, b)
print(c)
# array([[4, 5],
# [8, 10],
# [12, 15]], dtype=int32)
The stream on which to schedule the operation.
inner
Inner product of two vectors.
import mlx.core as mx
# Floating-point inputs are used here because integer inputs may be rejected
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])
c = mx.inner(a, b)
# 1*4 + 2*5 + 3*6 = 32
print(c) # array(32, dtype=float32)
The stream on which to schedule the operation.
Comparison Operations
equal
Element-wise equality comparison.
import mlx.core as mx
a = mx.array([1, 2, 3])
b = mx.array([1, 0, 3])
c = mx.equal(a, b)
print(c) # array([True, False, True], dtype=bool)
# Operator overload
d = a == b
print(d) # array([True, False, True], dtype=bool)
The stream on which to schedule the operation.
Boolean array of comparison results.
greater
Element-wise greater-than comparison.
import mlx.core as mx
a = mx.array([1, 3, 5])
b = mx.array([2, 2, 4])
c = mx.greater(a, b)
print(c) # array([False, True, True], dtype=bool)
# Operator overload
d = a > b
print(d) # array([False, True, True], dtype=bool)
The stream on which to schedule the operation.
Boolean array of comparison results.
less
Element-wise less-than comparison.
import mlx.core as mx
a = mx.array([1, 3, 5])
b = mx.array([2, 2, 4])
c = mx.less(a, b)
print(c) # array([True, False, False], dtype=bool)
# Operator overload
d = a < b
print(d) # array([True, False, False], dtype=bool)
The stream on which to schedule the operation.
Boolean array of comparison results.
allclose
Test if two arrays are element-wise equal within a tolerance.
import mlx.core as mx
a = mx.array([1.0, 2.0, 3.0])
b = mx.array([1.0001, 2.0001, 3.0001])
# The default tolerances (rtol=1e-05, atol=1e-08) are too tight for a difference of 1e-4
print(mx.allclose(a, b)) # array(False, dtype=bool)
# Looser tolerance
print(mx.allclose(a, b, rtol=1e-3, atol=1e-3)) # array(True, dtype=bool)
equal_nan
If True, NaNs are considered equal.
The stream on which to schedule the operation.
Boolean scalar indicating if arrays are close.
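The tolerance test is commonly written as |a - b| <= atol + rtol * |b|, so the allowed error grows with the magnitude of b. The pure-Python sketch below applies that formula to scalars and lists. The default values mirror NumPy's (rtol=1e-05, atol=1e-08) and are an assumption here; `is_close` and `all_close` are hypothetical helpers, not MLX functions.

```python
def is_close(a, b, rtol=1e-5, atol=1e-8):
    """Scalar closeness test: absolute term plus a term relative to |b|."""
    return abs(a - b) <= atol + rtol * abs(b)


def all_close(xs, ys, rtol=1e-5, atol=1e-8):
    """True only if every pair of elements passes the scalar test."""
    return all(is_close(x, y, rtol, atol) for x, y in zip(xs, ys))


# The same absolute difference (5e-4) passes or fails depending on magnitude.
print(is_close(100.0, 100.0005))  # True: allowed error is about 1e-3
print(is_close(1.0, 1.0005))      # False: allowed error is about 1e-5
print(all_close([1.0, 100.0], [1.0005, 100.0005], rtol=1e-3))  # True
```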
Concatenation and Stacking
concatenate
Concatenate arrays along an existing axis.
import mlx.core as mx
a = mx.array([[1, 2], [3, 4]])
b = mx.array([[5, 6]])
# Concatenate along axis 0
c = mx.concatenate([a, b], axis=0)
print(c)
# array([[1, 2],
# [3, 4],
# [5, 6]], dtype=int32)
# Concatenate along axis 1
d = mx.array([[5], [6]])
e = mx.concatenate([a, d], axis=1)
print(e)
# array([[1, 2, 5],
# [3, 4, 6]], dtype=int32)
List of arrays to concatenate.
Axis along which to concatenate.
The stream on which to schedule the operation.
stack
Stack arrays along a new axis.
import mlx.core as mx
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
# Stack along new axis 0
c = mx.stack([a, b], axis=0)
print(c)
# array([[1, 2, 3],
# [4, 5, 6]], dtype=int32)
print(c.shape) # (2, 3)
# Stack along new axis 1
d = mx.stack([a, b], axis=1)
print(d)
# array([[1, 4],
# [2, 5],
# [3, 6]], dtype=int32)
print(d.shape) # (3, 2)
Axis along which to stack.
The stream on which to schedule the operation.
Stacked array with one additional dimension.
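The difference between the two operations shows up in the result shape: concatenate sums the sizes along an existing axis, while stack inserts a new axis whose size is the number of inputs. The pure-Python sketch below computes both shapes under NumPy-style rules; the helpers are hypothetical and take shape tuples rather than arrays.

```python
def concatenate_shape(shapes, axis=0):
    """Result shape when joining arrays along an existing axis."""
    first = shapes[0]
    axis %= len(first)
    for shape in shapes[1:]:
        same_rank = len(shape) == len(first)
        if not same_rank or any(shape[i] != first[i] for i in range(len(first)) if i != axis):
            raise ValueError("all dimensions except the concatenation axis must match")
    total = sum(shape[axis] for shape in shapes)
    return first[:axis] + (total,) + first[axis + 1:]


def stack_shape(shapes, axis=0):
    """Result shape when joining equally shaped arrays along a new axis."""
    first = shapes[0]
    if any(shape != first for shape in shapes):
        raise ValueError("all input shapes must match")
    axis %= len(first) + 1
    return first[:axis] + (len(shapes),) + first[axis:]


print(concatenate_shape([(2, 2), (1, 2)], axis=0))  # (3, 2)
print(concatenate_shape([(2, 2), (2, 1)], axis=1))  # (2, 3)
print(stack_shape([(3,), (3,)], axis=0))            # (2, 3)
print(stack_shape([(3,), (3,)], axis=1))            # (3, 2)
```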
Indexing and Slicing
take
Take elements from an array along an axis.
import mlx.core as mx
a = mx.array([10, 20, 30, 40, 50])
indices = mx.array([0, 2, 4])
b = mx.take(a, indices)
print(b) # array([10, 30, 50], dtype=int32)
# 2D indexing
c = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
d = mx.take(c, mx.array([0, 2]), axis=0)
print(d)
# array([[1, 2, 3],
# [7, 8, 9]], dtype=int32)
Indices of elements to take.
Axis along which to take elements. If not specified, the array is flattened.
The stream on which to schedule the operation.
Array with selected elements.
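When no axis is given, the indices refer to positions in the flattened array. The pure-Python sketch below models that case with nested lists in row-major order; `flatten_nested` and `take_flat` are hypothetical helpers, not MLX functions.

```python
def flatten_nested(nested):
    """Flatten nested lists into one list in row-major order."""
    flat = []
    for item in nested:
        if isinstance(item, list):
            flat.extend(flatten_nested(item))
        else:
            flat.append(item)
    return flat


def take_flat(nested, indices):
    """Select elements by their position in the flattened data."""
    flat = flatten_nested(nested)
    return [flat[i] for i in indices]


grid = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(take_flat(grid, [0, 4, 8]))  # [1, 5, 9]
```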
where
Select elements from x or y depending on condition.
import mlx.core as mx
condition = mx.array([True, False, True, False])
x = mx.array([1, 2, 3, 4])
y = mx.array([10, 20, 30, 40])
result = mx.where(condition, x, y)
print(result) # array([1, 20, 3, 40], dtype=int32)
# Use with comparisons
a = mx.array([1, 2, 3, 4, 5])
b = mx.where(a > 2, a, 0)
print(b) # array([0, 0, 3, 4, 5], dtype=int32)
Values to select where condition is True.
Values to select where condition is False.
The stream on which to schedule the operation.
Array with elements from x where condition is True, otherwise from y.