
Overview

Random sampling functions in MLX use an implicit global PRNG state by default. However, every function accepts an optional key keyword argument when you need fine-grained control or explicit state management. MLX follows JAX's PRNG design and uses a splittable version of Threefry, a counter-based PRNG.

Basic Usage

Generate random numbers using the implicit global state:
import mlx.core as mx

for _ in range(3):
    print(mx.random.uniform())
Or pass an explicit key. Sampling with the same key always returns the same value:
import mlx.core as mx

key = mx.random.key(0)
for _ in range(3):
    print(mx.random.uniform(key=key))  # Same number each time

Key Management

key

mx.random.key(seed: int) -> array
Get a PRNG key from a seed.
seed (int, required): Seed value for the PRNG.
Returns: Array representing the PRNG key.
Example:
key = mx.random.key(0)

seed

mx.random.seed(seed: int) -> None
Seed the default global PRNG state.
seed (int, required): Seed value for the global PRNG.
Example:
mx.random.seed(42)

split

mx.random.split(key: array, num: int = 2, stream: StreamOrDevice = None) -> array
Split a PRNG key into subkeys.
key (array, required): The PRNG key to split.
num (int, default: 2): Number of subkeys to generate.
stream (StreamOrDevice, default: None): Stream or device to use.
Returns: Array of shape (num, 2) whose rows are the new keys. Because the result unpacks along its first axis, k1, k2 = mx.random.split(key) yields two separate keys.
Example:
key = mx.random.key(0)
k1, k2 = mx.random.split(key)
keys = mx.random.split(key, 10)  # Shape: (10, 2)

Continuous Distributions

uniform

mx.random.uniform(
    low: float | array = 0.0,
    high: float | array = 1.0,
    shape: tuple = (),
    dtype: Dtype = float32,
    key: array | None = None,
    stream: StreamOrDevice = None
) -> array
Generate uniformly distributed random numbers in the half-open interval [low, high).
low (float | array, default: 0.0): Lower bound of the distribution (inclusive).
high (float | array, default: 1.0): Upper bound of the distribution (exclusive).
shape (tuple, default: ()): Shape of the output array. low and high must be broadcastable to this shape.
dtype (Dtype, default: float32): Data type of the output.
key (array | None, default: None): PRNG key. If None, the global state is used.
stream (StreamOrDevice, default: None): Stream or device to use.
Returns: Array of uniformly distributed random values.
Example:
# Scalar value between 0 and 1
x = mx.random.uniform()

# Array of shape (2, 3)
x = mx.random.uniform(shape=(2, 3))

# Values between -1 and 5
x = mx.random.uniform(low=-1, high=5, shape=(1000,))

normal

mx.random.normal(
    loc: float = 0.0,
    scale: float = 1.0,
    shape: tuple = (),
    dtype: Dtype = float32,
    key: array | None = None,
    stream: StreamOrDevice = None
) -> array
Generate samples from a normal (Gaussian) distribution.
loc (float, default: 0.0): Mean of the distribution.
scale (float, default: 1.0): Standard deviation of the distribution.
shape (tuple, default: ()): Shape of the output array.
dtype (Dtype, default: float32): Data type of the output.
key (array | None, default: None): PRNG key. If None, the global state is used.
stream (StreamOrDevice, default: None): Stream or device to use.
Returns: Array of normally distributed random values.
Example:
# Standard normal
x = mx.random.normal(shape=(2, 3))

# Mean 1.0, std 2.0
x = mx.random.normal(loc=1.0, scale=2.0, shape=(100,))

multivariate_normal

mx.random.multivariate_normal(
    mean: array,
    cov: array,
    shape: tuple = (),
    dtype: Dtype = float32,
    key: array | None = None,
    stream: StreamOrDevice = None
) -> array
Generate samples from a multivariate normal distribution.
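mean (array, required): Mean of the distribution, with shape (..., n).
cov (array, required): Covariance matrix of the distribution, with shape (..., n, n). It must be positive semi-definite.
shape (tuple, default: ()): Batch shape of the output; each sample adds a trailing dimension of length n, so the result has shape shape + (n,). If empty, the batch shape is obtained by broadcasting the batch shapes of mean and cov.
dtype (Dtype, default: float32): Data type of the output (must be float32).
key (array | None, default: None): PRNG key. If None, the global state is used.
stream (StreamOrDevice, default: None): Stream or device to use.
Returns: Array of samples from the multivariate normal distribution.
Example: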
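mean = mx.array([0.0, 0.0])
cov = mx.array([[1.0, 0.0], [0.0, 1.0]])
# stream=mx.cpu: the covariance factorization may not be available on the GPU
x = mx.random.multivariate_normal(mean, cov, stream=mx.cpu)  # Shape: (2,)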

laplace

mx.random.laplace(
    loc: float = 0.0,
    scale: float = 1.0,
    shape: tuple = (),
    dtype: Dtype = float32,
    key: array | None = None,
    stream: StreamOrDevice = None
) -> array
Generate samples from the Laplace distribution.
loc (float, default: 0.0): Location parameter (mean) of the distribution.
scale (float, default: 1.0): Scale parameter of the distribution.
shape (tuple, default: ()): Shape of the output array.
dtype (Dtype, default: float32): Data type of the output.
key (array | None, default: None): PRNG key. If None, the global state is used.
stream (StreamOrDevice, default: None): Stream or device to use.
Returns: Array of samples from the Laplace distribution.
Example:
x = mx.random.laplace(shape=(100,))

truncated_normal

mx.random.truncated_normal(
    lower: array | float,
    upper: array | float,
    shape: tuple = None,
    dtype: Dtype = float32,
    key: array | None = None,
    stream: StreamOrDevice = None
) -> array
Generate samples from a truncated normal distribution.
lower (array | float, required): Lower bound for truncation.
upper (array | float, required): Upper bound for truncation.
shape (tuple, default: None): Shape of the output array. If None, the shape is inferred by broadcasting lower and upper.
dtype (Dtype, default: float32): Data type of the output.
key (array | None, default: None): PRNG key. If None, the global state is used.
stream (StreamOrDevice, default: None): Stream or device to use.
Returns: Array of samples from the truncated normal distribution.
Example:
x = mx.random.truncated_normal(lower=-2, upper=2, shape=(100,))

gumbel

mx.random.gumbel(
    shape: tuple = (),
    dtype: Dtype = float32,
    key: array | None = None,
    stream: StreamOrDevice = None
) -> array
Generate samples from the standard Gumbel distribution.
shape (tuple, default: ()): Shape of the output array.
dtype (Dtype, default: float32): Data type of the output.
key (array | None, default: None): PRNG key. If None, the global state is used.
stream (StreamOrDevice, default: None): Stream or device to use.
Returns: Array of samples from the Gumbel distribution.
Example:
x = mx.random.gumbel(shape=(100,))

Discrete Distributions

randint

mx.random.randint(
    low: int | array,
    high: int | array,
    shape: tuple = (),
    dtype: Dtype = int32,
    key: array | None = None,
    stream: StreamOrDevice = None
) -> array
Generate random integers uniformly from the range [low, high).
low (int | array, required): Lower bound (inclusive).
high (int | array, required): Upper bound (exclusive).
shape (tuple, default: ()): Shape of the output array.
dtype (Dtype, default: int32): Data type of the output.
key (array | None, default: None): PRNG key. If None, the global state is used.
stream (StreamOrDevice, default: None): Stream or device to use.
Returns: Array of random integers.
Example:
# Integers from 0 to 9
x = mx.random.randint(0, 10, shape=(100,))

bernoulli

mx.random.bernoulli(
    p: float | array = 0.5,
    shape: tuple = None,
    key: array | None = None,
    stream: StreamOrDevice = None
) -> array
Generate binary random variables with probability p of being True.
p (float | array, default: 0.5): Probability of generating True (1).
shape (tuple, default: None): Shape of the output array. If None, the shape of p is used.
key (array | None, default: None): PRNG key. If None, the global state is used.
stream (StreamOrDevice, default: None): Stream or device to use.
Returns: Array of boolean values.
Example:
# Coin flips
x = mx.random.bernoulli(p=0.5, shape=(100,))

# 30% chance of True
x = mx.random.bernoulli(p=0.3, shape=(100,))

categorical

mx.random.categorical(
    logits: array,
    axis: int = -1,
    shape: tuple = None,
    num_samples: int = None,
    key: array | None = None,
    stream: StreamOrDevice = None
) -> array
Sample from a categorical distribution.
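logits (array, required): Unnormalized log probabilities for each category.
axis (int, default: -1): Axis of logits that indexes the categories.
shape (tuple, default: None): Shape of the output array. It must be broadcast-compatible with logits.shape with axis removed.
num_samples (int, default: None): Number of samples to draw from each distribution; it becomes the last dimension of the output. Specify at most one of shape and num_samples. If neither is given, one sample is drawn per distribution.
key (array | None, default: None): PRNG key. If None, the global state is used.
stream (StreamOrDevice, default: None): Stream or device to use.
Returns: Array of categorical samples (indices).
Example: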
# Sample from 3 categories with equal probability
logits = mx.array([0.0, 0.0, 0.0])
x = mx.random.categorical(logits, num_samples=100)

Permutations

permutation

mx.random.permutation(
    x: array | int,
    axis: int = 0,
    key: array | None = None,
    stream: StreamOrDevice = None
) -> array
Randomly permute elements.
x (array | int, required): If an array, permute its elements along the given axis. If an integer, equivalent to permutation(mx.arange(x)).
axis (int, default: 0): Axis along which to permute. Only used if x is an array.
key (array | None, default: None): PRNG key. If None, the global state is used.
stream (StreamOrDevice, default: None): Stream or device to use.
Returns: Permuted array.
Example:
# Random permutation of 0 to 9
x = mx.random.permutation(10)

# Shuffle array rows
arr = mx.array([[1, 2], [3, 4], [5, 6]])
shuffled = mx.random.permutation(arr, axis=0)
