
Overview

The mlx.core.fft module provides functions for computing discrete Fourier transforms. These functions support 1D, 2D, and n-dimensional transforms for both complex and real-valued inputs.

One-Dimensional Transforms

fft

mx.fft.fft(
    a: array,
    n: int = None,
    axis: int = -1,
    stream: StreamOrDevice = None
) -> array
Compute the one-dimensional discrete Fourier Transform.
Parameters:
  a (array, required): Input array (can be complex or real).
  n (int, default None): Length of the transformed axis. If n is smaller than the input length along axis, the input is cropped; if larger, it is zero-padded. If None, the input length along axis is used.
  axis (int, default -1): Axis over which to compute the FFT.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array containing the FFT.

Example:
import mlx.core as mx

# FFT of a real signal
x = mx.array([1.0, 2.0, 3.0, 4.0])
X = mx.fft.fft(x)

# With padding
X = mx.fft.fft(x, n=8)
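The sketch below makes the cropping and zero-padding behavior of n concrete. It is a plain-Python reference for the transform that fft computes, written as a direct O(n²) sum rather than a fast algorithm. The helper name `dft` is hypothetical and is not part of MLX; it is mainly useful for checking small results by hand.

```python
import cmath

def dft(x, n=None):
    """Direct DFT mirroring fft's n semantics (hypothetical helper, not MLX)."""
    if n is None:
        n = len(x)
    # Crop to the first n samples, or zero-pad up to n samples.
    x = list(x[:n]) + [0.0] * max(0, n - len(x))
    # X[k] = sum_t x[t] * exp(-2*pi*i*k*t/n)
    return [
        sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]

X = dft([1.0, 2.0, 3.0, 4.0])         # approximately [10, -2+2j, -2, -2-2j]
X8 = dft([1.0, 2.0, 3.0, 4.0], n=8)   # zero-padded to 8 bins
X2 = dft([1.0, 2.0, 3.0, 4.0], n=2)   # cropped to [1.0, 2.0]: approximately [3, -1]
```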

ifft

mx.fft.ifft(
    a: array,
    n: int = None,
    axis: int = -1,
    stream: StreamOrDevice = None
) -> array
Compute the one-dimensional inverse discrete Fourier Transform.
Parameters:
  a (array, required): Input array (typically complex-valued).
  n (int, default None): Length of the transformed axis; the input is cropped or zero-padded as in fft.
  axis (int, default -1): Axis over which to compute the inverse FFT.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array containing the inverse FFT.

Example:
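import mlx.core as mx
x = mx.array([1.0, 2.0, 3.0, 4.0])
X = mx.fft.fft(x)
x_reconstructed = mx.fft.ifft(X)  # complex-valued; real part matches x up to rounding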
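Reconstruction works because the inverse transform carries a 1/n factor. The sketch below assumes the NumPy-style convention: no scaling on the forward transform and 1/n on the inverse. It is a direct-sum illustration with hypothetical helper names, not MLX code.

```python
import cmath

def dft(x):
    """Forward DFT, unscaled (hypothetical helper)."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n)]

def idft(X):
    """Inverse DFT: opposite sign in the exponent plus the 1/n factor (hypothetical helper)."""
    n = len(X)
    return [sum(X[k] * cmath.exp(2j * cmath.pi * k * t / n) for k in range(n)) / n
            for t in range(n)]

x = [1.0, 2.0, 3.0, 4.0]
x_reconstructed = idft(dft(x))  # complex values; real parts match x up to rounding
```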

rfft

mx.fft.rfft(
    a: array,
    n: int = None,
    axis: int = -1,
    stream: StreamOrDevice = None
) -> array
Compute the one-dimensional FFT of a real-valued input. This function is more efficient than fft for real inputs: by Hermitian symmetry the negative-frequency bins are redundant, so it computes only the non-negative frequency components. Along the transformed axis, the output contains n//2 + 1 complex values.
Parameters:
  a (array, required): Real-valued input array.
  n (int, default None): Length of the transformed axis; the input is cropped or zero-padded to n.
  axis (int, default -1): Axis over which to compute the FFT.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array whose length along the transformed axis is n//2 + 1.

Example:
import mlx.core as mx
x = mx.array([1.0, 2.0, 3.0, 4.0])
X = mx.fft.rfft(x)  # Output shape: (3,)
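For a real input, the DFT satisfies X[n - k] = conj(X[k]), so the bins above n // 2 carry no new information. This is why rfft can return only n // 2 + 1 values. The plain-Python sketch below uses hypothetical helper names and direct sums rather than an FFT. It keeps exactly those bins and checks the symmetry on a length-5 signal.

```python
import cmath

def dft(x):
    """Full direct DFT (hypothetical helper)."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n)]

def rfft_ref(x):
    """Keep only bins 0..n//2 of the full DFT of a real signal (hypothetical helper)."""
    return dft(x)[: len(x) // 2 + 1]

x = [1.0, 2.0, 3.0, 4.0, 5.0]
n = len(x)
full = dft(x)
half = rfft_ref(x)  # 5 // 2 + 1 = 3 bins
# Every dropped bin is the conjugate of a kept one: full[n - k] == conj(full[k]).
symmetric = all(abs(full[n - k] - full[k].conjugate()) < 1e-9 for k in range(1, n))
```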

irfft

mx.fft.irfft(
    a: array,
    n: int = None,
    axis: int = -1,
    stream: StreamOrDevice = None
) -> array
Compute the inverse of rfft.
Parameters:
  a (array, required): Complex-valued input array (e.g. the output of rfft).
  n (int, default None): Length of the output along the transformed axis. If None, it is computed as 2 * (m - 1), where m is the length of the input along the transformed axis. Because this default is always even, pass n explicitly to recover an odd-length signal.
  axis (int, default -1): Axis over which to compute the inverse FFT.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Real-valued array.

Example:
import mlx.core as mx
x = mx.array([1.0, 2.0, 3.0, 4.0])
X = mx.fft.rfft(x)
x_reconstructed = mx.fft.irfft(X)  # default length 2 * (3 - 1) = 4
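The default output length 2 * (m - 1) is always even, so an odd-length signal comes back one sample short unless n is passed. The sketch below is a plain-Python illustration with hypothetical helpers, not MLX's implementation. It rebuilds the full spectrum from the kept bins using Hermitian symmetry and then inverts it with the 1/n factor.

```python
import cmath

def rfft_ref(x):
    """Bins 0..n//2 of the DFT of a real signal (hypothetical helper)."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n // 2 + 1)]

def irfft_ref(X, n=None):
    """Inverse of rfft_ref; default output length 2 * (m - 1) (hypothetical helper)."""
    m = len(X)
    if n is None:
        n = 2 * (m - 1)
    # Crop or zero-pad the half spectrum to n//2 + 1 bins.
    half = list(X[: n // 2 + 1]) + [0j] * max(0, n // 2 + 1 - m)
    # Fill the negative-frequency bins via full[n - k] = conj(full[k]).
    full = [half[k] if k <= n // 2 else half[n - k].conjugate() for k in range(n)]
    return [
        (sum(full[k] * cmath.exp(2j * cmath.pi * k * t / n) for k in range(n)) / n).real
        for t in range(n)
    ]

x4 = [1.0, 2.0, 3.0, 4.0]
x5 = [1.0, 2.0, 3.0, 4.0, 5.0]
r4 = irfft_ref(rfft_ref(x4))            # length 4, matches x4 up to rounding
r5_default = irfft_ref(rfft_ref(x5))    # length 4, not 5
r5 = irfft_ref(rfft_ref(x5), n=5)       # length 5, matches x5 up to rounding
```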

Two-Dimensional Transforms

fft2

mx.fft.fft2(
    a: array,
    s: tuple = None,
    axes: tuple = (-2, -1),
    stream: StreamOrDevice = None
) -> array
Compute the two-dimensional discrete Fourier Transform.
Parameters:
  a (array, required): Input array.
  s (tuple, default None): Shape of the output along the transformed axes; each transformed axis is cropped or zero-padded to the matching entry. If None, the input's sizes along axes are used.
  axes (tuple, default (-2, -1)): Axes over which to compute the FFT.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array containing the 2D FFT.

Example:
import mlx.core as mx

x = mx.random.uniform(shape=(8, 8))
X = mx.fft.fft2(x)
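A 2D DFT is separable: transforming every row and then every column gives the same result as the double-sum definition. The plain-Python sketch below (hypothetical helper names, direct sums) checks that equivalence for the default axes (-2, -1) on a small 2 x 3 input.

```python
import cmath

def dft(x):
    """1D direct DFT (hypothetical helper)."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n)]

def dft2_separable(a):
    """2D DFT as a 1D DFT of each row (axis -1), then of each column (axis -2)."""
    rows = [dft(row) for row in a]
    cols = [dft(list(col)) for col in zip(*rows)]  # transform along axis -2
    return [list(r) for r in zip(*cols)]           # transpose back to (M, N)

def dft2_direct(a):
    """Double-sum definition of the 2D DFT, for comparison."""
    M, N = len(a), len(a[0])
    return [[sum(a[i][j] * cmath.exp(-2j * cmath.pi * (u * i / M + v * j / N))
                 for i in range(M) for j in range(N))
             for v in range(N)]
            for u in range(M)]

a = [[1.0, 2.0, 0.0],
     [0.0, 1.0, 3.0]]
sep = dft2_separable(a)
direct = dft2_direct(a)  # agrees with sep up to rounding
```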

ifft2

mx.fft.ifft2(
    a: array,
    s: tuple = None,
    axes: tuple = (-2, -1),
    stream: StreamOrDevice = None
) -> array
Compute the two-dimensional inverse discrete Fourier Transform.
Parameters:
  a (array, required): Input array (typically complex-valued).
  s (tuple, default None): Shape of the output along the transformed axes.
  axes (tuple, default (-2, -1)): Axes over which to compute the inverse FFT.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array containing the 2D inverse FFT.

Example:
import mlx.core as mx
x = mx.random.uniform(shape=(8, 8))
X = mx.fft.fft2(x)
x_reconstructed = mx.fft.ifft2(X)

rfft2

mx.fft.rfft2(
    a: array,
    s: tuple = None,
    axes: tuple = (-2, -1),
    stream: StreamOrDevice = None
) -> array
Compute the two-dimensional FFT of a real-valued input.
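Parameters:
  a (array, required): Real-valued input array.
  s (tuple, default None): Shape of the output along the transformed axes.
  axes (tuple, default (-2, -1)): Axes over which to compute the FFT.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array. Along the last axis in axes, its length is m // 2 + 1, where m is the corresponding entry of s (or the input size along that axis if s is None). The other transformed axes keep their full length.

Example: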
import mlx.core as mx
x = mx.random.uniform(shape=(8, 8))
X = mx.fft.rfft2(x)  # Output shape: (8, 5)

irfft2

mx.fft.irfft2(
    a: array,
    s: tuple = None,
    axes: tuple = (-2, -1),
    stream: StreamOrDevice = None
) -> array
Compute the inverse of rfft2.
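Parameters:
  a (array, required): Complex-valued input array (e.g. the output of rfft2).
  s (tuple, default None): Shape of the output along the transformed axes. If None, the last transformed axis defaults to 2 * (m - 1), as in irfft, where m is its input length; pass s to recover an odd size.
  axes (tuple, default (-2, -1)): Axes over which to compute the inverse FFT.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Real-valued array.

Example: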
import mlx.core as mx
x = mx.random.uniform(shape=(8, 8))
X = mx.fft.rfft2(x)
x_reconstructed = mx.fft.irfft2(X)  # shape (8, 8)

N-Dimensional Transforms

fftn

mx.fft.fftn(
    a: array,
    s: tuple = None,
    axes: tuple = None,
    stream: StreamOrDevice = None
) -> array
Compute the n-dimensional discrete Fourier Transform.
Parameters:
  a (array, required): Input array.
  s (tuple, default None): Shape of the output along the transformed axes. If not specified, uses the shape of the input.
  axes (tuple, default None): Axes over which to compute the FFT. If None, transforms over all axes.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array containing the n-D FFT.

Example:
import mlx.core as mx

x = mx.random.uniform(shape=(4, 4, 4))
X = mx.fft.fftn(x)

# Transform specific axes
X = mx.fft.fftn(x, axes=(0, 2))
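Conceptually, fftn applies a 1D transform along each axis in axes in turn, and any axis not listed is left untouched. So that it needs only the standard library, the sketch below stores a small n-D array as a dictionary keyed by index tuples. The helper names are hypothetical, and the direct sums are for illustration, not performance.

```python
import cmath
from itertools import product

def dft_along_axis(data, shape, axis):
    """Direct 1D DFT along one axis of an n-D array stored as {index: value}."""
    n = shape[axis]
    out = {}
    for idx in product(*(range(d) for d in shape)):
        k = idx[axis]  # output frequency index along this axis
        out[idx] = sum(
            data[idx[:axis] + (t,) + idx[axis + 1:]] * cmath.exp(-2j * cmath.pi * k * t / n)
            for t in range(n)
        )
    return out

def fftn_ref(data, shape, axes=None):
    """n-D DFT as successive 1D DFTs; axes=None means every axis (hypothetical helper)."""
    ndim = len(shape)
    axes = range(ndim) if axes is None else [ax % ndim for ax in axes]
    for axis in axes:
        data = dft_along_axis(data, shape, axis)
    return data

shape = (2, 2, 2)
ones = {idx: 1.0 for idx in product(*(range(d) for d in shape))}
full = fftn_ref(ones, shape)                  # all energy lands in bin (0, 0, 0)
partial = fftn_ref(ones, shape, axes=(0, 2))  # axis 1 is not transformed
```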

ifftn

mx.fft.ifftn(
    a: array,
    s: tuple = None,
    axes: tuple = None,
    stream: StreamOrDevice = None
) -> array
Compute the n-dimensional inverse discrete Fourier Transform.
Parameters:
  a (array, required): Input array (typically complex-valued).
  s (tuple, default None): Shape of the output along the transformed axes.
  axes (tuple, default None): Axes over which to compute the inverse FFT. If None, transforms over all axes.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array containing the n-D inverse FFT.

Example:
import mlx.core as mx
x = mx.random.uniform(shape=(4, 4, 4))
X = mx.fft.fftn(x)
x_reconstructed = mx.fft.ifftn(X)

rfftn

mx.fft.rfftn(
    a: array,
    s: tuple = None,
    axes: tuple = None,
    stream: StreamOrDevice = None
) -> array
Compute the n-dimensional FFT of a real-valued input.
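Parameters:
  a (array, required): Real-valued input array.
  s (tuple, default None): Shape of the output along the transformed axes.
  axes (tuple, default None): Axes over which to compute the FFT. If None, transforms over all axes.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Complex-valued array. Along the last axis in axes, its length is m // 2 + 1, where m is the corresponding entry of s (or the input size along that axis if s is None). The other transformed axes keep their full length.

Example: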
import mlx.core as mx
x = mx.random.uniform(shape=(4, 4, 4))
X = mx.fft.rfftn(x)  # Output shape: (4, 4, 3)
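A real n-D transform halves only the last axis in axes; the other transformed axes keep their full length. The small helper below computes the resulting output shape. Its name is hypothetical and introduced for illustration, and it assumes s is None so that the input sizes are used.

```python
def rfftn_output_shape(shape, axes=None):
    """Output shape of a real n-D FFT when s is None (hypothetical helper)."""
    ndim = len(shape)
    # axes=None means every axis; negative axes are normalized.
    axes = list(range(ndim)) if axes is None else [ax % ndim for ax in axes]
    out = list(shape)
    last = axes[-1]
    out[last] = shape[last] // 2 + 1  # only the last transformed axis is halved
    return tuple(out)

print(rfftn_output_shape((4, 4, 4)))              # (4, 4, 3)
print(rfftn_output_shape((8, 8), axes=(-2, -1)))  # (8, 5)
```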

irfftn

mx.fft.irfftn(
    a: array,
    s: tuple = None,
    axes: tuple = None,
    stream: StreamOrDevice = None
) -> array
Compute the inverse of rfftn.
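Parameters:
  a (array, required): Complex-valued input array (e.g. the output of rfftn).
  s (tuple, default None): Shape of the output along the transformed axes. If None, the last transformed axis defaults to 2 * (m - 1), as in irfft, where m is its input length; pass s to recover an odd size.
  axes (tuple, default None): Axes over which to compute the inverse FFT. If None, transforms over all axes.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Real-valued array.

Example: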
import mlx.core as mx
x = mx.random.uniform(shape=(4, 4, 4))
X = mx.fft.rfftn(x)
x_reconstructed = mx.fft.irfftn(X)  # shape (4, 4, 4)
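When s is None, the inverse real transforms set the last transformed axis to 2 * (m - 1). An odd original size therefore has to be passed back through s. The hypothetical helper below computes that default output shape, assuming irfftn applies the same rule as irfft.

```python
def irfftn_output_shape(shape, axes=None):
    """Default output shape of an inverse real n-D FFT when s is None (hypothetical helper)."""
    ndim = len(shape)
    axes = list(range(ndim)) if axes is None else [ax % ndim for ax in axes]
    out = list(shape)
    last = axes[-1]
    out[last] = 2 * (shape[last] - 1)  # assumed to match irfft's default length
    return tuple(out)

print(irfftn_output_shape((4, 4, 3)))  # (4, 4, 4): round trip of a (4, 4, 4) input
print(irfftn_output_shape((4, 4, 5)))  # (4, 4, 8): a (4, 4, 9) original needs s=(4, 4, 9)
```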

Helper Functions

fftshift

mx.fft.fftshift(
    a: array,
    axes: tuple = None,
    stream: StreamOrDevice = None
) -> array
Shift the zero-frequency component to the center of the spectrum. This function swaps the two half-spaces along each axis in axes (all axes by default), moving the zero-frequency component from the beginning of each axis to index n // 2.
Parameters:
  a (array, required): Input array.
  axes (tuple, default None): Axes over which to shift. If None, shifts all axes.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Shifted array.

Example:
import mlx.core as mx
x = mx.array([1.0, 2.0, 3.0, 4.0])
X = mx.fft.fft(x)
X_centered = mx.fft.fftshift(X)
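fftshift amounts to rotating each shifted axis right by n // 2 positions. The plain-Python sketch below shows this for a single axis. It pairs the rotation with a helper that lists the integer frequency of each DFT bin in standard order (0, 1, ..., then the negative frequencies), so the shifted order comes out ascending. Both helper names are hypothetical.

```python
def fftshift_1d(x):
    """Rotate a list right by n // 2 so index 0 lands at index n // 2 (hypothetical helper)."""
    k = len(x) // 2
    return list(x[-k:]) + list(x[:-k]) if k else list(x)

def bin_frequencies(n):
    """Integer frequency of each DFT bin in standard (unshifted) order."""
    return [k if k < (n + 1) // 2 else k - n for k in range(n)]

print(bin_frequencies(5))               # [0, 1, 2, -2, -1]
print(fftshift_1d(bin_frequencies(5)))  # [-2, -1, 0, 1, 2]
```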

ifftshift

mx.fft.ifftshift(
    a: array,
    axes: tuple = None,
    stream: StreamOrDevice = None
) -> array
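The inverse of fftshift. This function undoes the effect of fftshift, moving the zero-frequency component back to the beginning of each axis. For odd-length axes the two functions differ, so use ifftshift, not a second fftshift, to undo a shift.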
Parameters:
  a (array, required): Input array.
  axes (tuple, default None): Axes over which to shift. If None, shifts all axes.
  stream (StreamOrDevice, default None): Stream or device to use.
Returns: Shifted array.

Example:
import mlx.core as mx
X = mx.fft.fft(mx.array([1.0, 2.0, 3.0, 4.0, 5.0]))
X_centered = mx.fft.fftshift(X)
X_original = mx.fft.ifftshift(X_centered)  # same as X
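ifftshift rotates by -(n // 2) instead of +(n // 2). For even lengths the two rotations coincide. For odd lengths, however, applying fftshift twice does not restore the input, which is why a separate inverse exists. The one-axis plain-Python sketch below uses hypothetical helper names.

```python
def fftshift_1d(x):
    """Rotate right by n // 2 (hypothetical helper)."""
    k = len(x) // 2
    return list(x[-k:]) + list(x[:-k]) if k else list(x)

def ifftshift_1d(x):
    """Rotate left by n // 2, undoing fftshift_1d for any length (hypothetical helper)."""
    k = len(x) // 2
    return list(x[k:]) + list(x[:k])

x = [0, 1, 2, 3, 4]
shifted = fftshift_1d(x)      # [3, 4, 0, 1, 2]
print(ifftshift_1d(shifted))  # [0, 1, 2, 3, 4]
print(fftshift_1d(shifted))   # [1, 2, 3, 4, 0], not the original
```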
