Let’s implement a basic linear regression model as a starting point to learn MLX. This example demonstrates how to use MLX’s automatic differentiation with mx.grad() to train a model using stochastic gradient descent.

Setup

First, import the core package and define the problem sizes and optimization hyperparameters:
import mlx.core as mx

num_features = 100
num_examples = 1_000
num_iters = 10_000  # iterations of SGD
lr = 0.01  # learning rate for SGD
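
A back-of-the-envelope check of why these hyperparameters are comfortable. This is not part of the original example, and the eigenvalue range is an asymptotic Marchenko-Pastur approximation for Gaussian X, not an exact bound for one sample. The squared loss used below, L(w) = 0.5 · mean((Xw − y)²), is quadratic with Hessian H = XᵀX/n. Gradient descent with step size lr is stable when lr < 2/λ_max(H), and the distance to the minimizer shrinks by at most a factor 1 − lr·λ_min(H) per step. With d/n = 100/1000 = 0.1:

```latex
\begin{aligned}
L(w) &= \frac{1}{2n}\,\lVert Xw - y \rVert^2, \qquad H = \nabla^2 L(w) = \frac{1}{n} X^\top X \\
\lambda(H) &\in \big[(1-\sqrt{d/n})^2,\ (1+\sqrt{d/n})^2\big] \approx [0.47,\ 1.73] \\
\text{stability:}\quad \mathrm{lr} &< \frac{2}{\lambda_{\max}(H)} \approx 1.15 \\
\text{per-step contraction:}\quad 1 - \mathrm{lr}\,\lambda_{\min}(H) &\approx 1 - 0.01 \times 0.47 \approx 0.9953 \\
0.9953^{10\,000} &\approx e^{-47}
\end{aligned}
```

So lr = 0.01 sits far below the stability limit, and 10,000 iterations shrink the initial optimization error to a negligible size.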

Generate Synthetic Data

We’ll create a synthetic dataset in three steps:

1. Sample the design matrix X: generate random input features from a normal distribution.
2. Sample the ground-truth parameters w_star: create the true parameter vector we want to recover.
3. Compute the noisy labels y: add Gaussian noise to X @ w_star to get the dependent values.

# True parameters
w_star = mx.random.normal((num_features,))

# Input examples (design matrix)
X = mx.random.normal((num_examples, num_features))

# Noisy labels
eps = 1e-2 * mx.random.normal((num_examples,))
y = X @ w_star + eps
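
The generated data follow a linear model with Gaussian noise of standard deviation σ = 10⁻² (from the 1e-2 scale on eps), and that noise sets the smallest loss you can expect. At the true parameters the residual is just the noise, so with L(w) = 0.5 · mean((Xw − y)²), the loss defined in the next section:

```latex
\begin{aligned}
y &= X w^\star + \varepsilon, \qquad \varepsilon_i \sim \mathcal{N}(0, \sigma^2),\quad \sigma = 10^{-2} \\
\mathbb{E}\,[L(w^\star)] &= \frac{1}{2n}\,\mathbb{E}\,\lVert \varepsilon \rVert^2 = \frac{\sigma^2}{2} = 5 \times 10^{-5}
\end{aligned}
```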

Define Loss and Gradient

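We’ll use gradient descent to find the optimal weights. Every step below evaluates the loss on all num_examples rows of X, so this is strictly full-batch gradient descent, even though the code comments call it SGD; a stochastic variant would use a random minibatch per step. Define the squared loss and get the gradient function of the loss with respect to the parameters: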
def loss_fn(w):
    return 0.5 * mx.mean(mx.square(X @ w - y))

grad_fn = mx.grad(loss_fn)

mx.grad(loss_fn) returns a new function that computes the gradient of loss_fn with respect to its first argument, w. Calling grad_fn(w) returns an array with the same shape as w.
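
For this loss the gradient also has a closed form, ∇L(w) = (1/n) Xᵀ(Xw − y), so grad_fn(w) should match it. As a sanity check that runs without MLX, the sketch below compares that formula with central finite differences in plain Python. The helper names (residuals, loss, analytic_grad, numeric_grad), the tiny problem sizes, and the seed are illustrative assumptions, not part of the MLX example.

```python
import random


def residuals(w, X, y):
    """X @ w - y, one entry per example."""
    return [sum(x * wj for x, wj in zip(row, w)) - yi for row, yi in zip(X, y)]


def loss(w, X, y):
    """Squared loss 0.5 * mean((X @ w - y)^2), matching loss_fn above."""
    r = residuals(w, X, y)
    return 0.5 * sum(ri * ri for ri in r) / len(X)


def analytic_grad(w, X, y):
    """Closed-form gradient (1/n) * X^T (X @ w - y)."""
    n, r = len(X), residuals(w, X, y)
    return [sum(X[i][j] * r[i] for i in range(n)) / n for j in range(len(w))]


def numeric_grad(w, X, y, h=1e-6):
    """Central finite differences: (L(w + h e_j) - L(w - h e_j)) / (2h)."""
    grad = []
    for j in range(len(w)):
        w_plus, w_minus = list(w), list(w)
        w_plus[j] += h
        w_minus[j] -= h
        grad.append((loss(w_plus, X, y) - loss(w_minus, X, y)) / (2 * h))
    return grad


rng = random.Random(0)
n, d = 20, 3  # tiny illustrative sizes
X = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
y = [rng.gauss(0, 1) for _ in range(n)]
w = [rng.gauss(0, 1) for _ in range(d)]

g_analytic = analytic_grad(w, X, y)
g_numeric = numeric_grad(w, X, y)
max_diff = max(abs(a - b) for a, b in zip(g_analytic, g_numeric))
print(f"max |analytic - numeric| = {max_diff:.2e}")
```

Because the loss is quadratic, central differences are exact up to floating-point rounding, so the two gradients should agree to many digits. In MLX you would normally skip this derivation and rely on mx.grad.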

Training Loop

Initialize the parameters randomly, then repeatedly update them for num_iters iterations:
w = 1e-2 * mx.random.normal((num_features,))

for _ in range(num_iters):
    grad = grad_fn(w)
    w = w - lr * grad
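    mx.eval(w)

MLX evaluates lazily: operations only record a computation graph until a result is needed. Calling mx.eval(w) on every iteration forces the update to be computed. Without it, each new w would depend on the still-unevaluated previous w, and the graph would keep growing across iterations.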
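
To see why evaluating each iteration matters, here is a deliberately simplified toy model of lazy evaluation in plain Python. It is not how MLX is implemented; the Lazy class, toy_eval, and run are hypothetical. Each arithmetic operation only records a node, so the update w = w - lr * grad stacks new nodes on top of the previous unevaluated w unless something forces evaluation.

```python
class Lazy:
    """Toy lazy scalar (hypothetical, not MLX): records operations instead of running them."""

    def __init__(self, op, args=(), value=None):
        self.op = op        # "const", "sub", or "mul"
        self.args = args    # input nodes this operation depends on
        self.value = value  # concrete result, set once evaluated
        # Number of pending operations on the longest path to a concrete value.
        self.depth = 0 if value is not None else 1 + max(a.depth for a in args)

    def __sub__(self, other):
        return Lazy("sub", (self, other))

    def __rmul__(self, scalar):
        # Handles `float * Lazy`, as in `lr * grad`.
        return Lazy("mul", (Lazy("const", value=scalar), self))


def toy_eval(node):
    """Compute and cache the node's value, then drop its recorded history."""
    if node.value is None:
        a, b = (toy_eval(arg) for arg in node.args)
        node.value = a - b if node.op == "sub" else a * b
        node.args, node.depth = (), 0
    return node.value


def run(num_iters, lr=0.1, eval_each_step=True):
    """Gradient descent on f(w) = 0.5 * (w - 3)^2, whose gradient is w - 3."""
    target = Lazy("const", value=3.0)
    w = Lazy("const", value=0.0)
    for _ in range(num_iters):
        grad = w - target      # recorded, not computed
        w = w - lr * grad      # new node that points back at the old w
        if eval_each_step:
            toy_eval(w)        # plays the role of mx.eval(w) in this toy
    pending_depth = w.depth
    return toy_eval(w), pending_depth


value_eager, depth_eager = run(50, eval_each_step=True)
value_lazy, depth_lazy = run(50, eval_each_step=False)
print(depth_eager, depth_lazy)  # prints: 0 150
```

Both runs reach the same value; the difference is how much pending work piles up first. In the lazy run the final value sits behind a chain of 150 pending operations (three per iteration), which is the kind of growth that calling mx.eval(w) each step avoids.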

Evaluate Results

Finally, compute the loss of the learned parameters and verify that they are close to the ground truth parameters:
loss = loss_fn(w)
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5

print(f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}")
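
To reproduce the shape of this result without MLX, here is a plain-Python sketch of the whole pipeline at a much smaller, hypothetical scale (3 features, 50 examples, lr = 0.1, 2,000 iterations, seed 42). The gradient is written out by hand as (1/n) Xᵀ(Xw − y), which is exactly the step mx.grad automates in the MLX version.

```python
import math
import random

# Seeded for reproducibility (an assumption; the MLX example sets no seed).
rng = random.Random(42)

# Small, hypothetical problem sizes so plain Python stays fast.
num_features, num_examples = 3, 50
num_iters, lr = 2_000, 0.1

# Same data model as above: y = X @ w_star + eps, with eps ~ N(0, (1e-2)^2).
w_star = [rng.gauss(0, 1) for _ in range(num_features)]
X = [[rng.gauss(0, 1) for _ in range(num_features)] for _ in range(num_examples)]
y = [sum(x * s for x, s in zip(row, w_star)) + 1e-2 * rng.gauss(0, 1) for row in X]


def residuals(w):
    """X @ w - y, one entry per example."""
    return [sum(x * wj for x, wj in zip(row, w)) - yi for row, yi in zip(X, y)]


def loss_fn(w):
    """0.5 * mean((X @ w - y)^2)."""
    r = residuals(w)
    return 0.5 * sum(ri * ri for ri in r) / num_examples


def grad_fn(w):
    """Hand-derived gradient (1/n) * X^T (X @ w - y)."""
    r = residuals(w)
    return [sum(X[i][j] * r[i] for i in range(num_examples)) / num_examples
            for j in range(num_features)]


# Small random initialization, then full-batch gradient descent.
w = [1e-2 * rng.gauss(0, 1) for _ in range(num_features)]
for _ in range(num_iters):
    g = grad_fn(w)
    w = [wj - lr * gj for wj, gj in zip(w, g)]

loss = loss_fn(w)
error_norm = math.sqrt(sum((wj - sj) ** 2 for wj, sj in zip(w, w_star)))
print(f"Loss {loss:.5f}, |w-w*| = {error_norm:.5f}")
```

Expect a small loss and a small |w-w*|; the precise digits depend on the seed and the reduced sizes, so they will not match the MLX output exactly.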

Expected Output

Loss 0.00005, |w-w*| = 0.00364
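Your exact numbers may differ, since the data and the initial weights are drawn at random. The learned parameters should still end up very close to the ground truth, which confirms that the gradients computed by mx.grad() drive gradient descent to a good solution.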
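
Why the error does not reach zero: gradient descent converges to the least-squares solution ŵ, not to w_star, and ŵ differs from w_star because of the noise. A standard result for ordinary least squares with Gaussian X (a rough expectation, not a guarantee for any single run) gives the magnitudes below, which are in line with the expected output above:

```latex
\begin{aligned}
\hat w &= (X^\top X)^{-1} X^\top y = w^\star + (X^\top X)^{-1} X^\top \varepsilon \\
\mathbb{E}\,\lVert \hat w - w^\star \rVert^2 &= \sigma^2\, \mathbb{E}\,\mathrm{tr}\big[(X^\top X)^{-1}\big]
  = \frac{\sigma^2 d}{n - d - 1} = \frac{10^{-4} \cdot 100}{899} \approx 1.1 \times 10^{-5}
  \;\Rightarrow\; \lVert \hat w - w^\star \rVert \approx 0.0033 \\
\mathbb{E}\,[L(\hat w)] &= \frac{\sigma^2 (n - d)}{2n} = \frac{10^{-4} \cdot 900}{2000} = 4.5 \times 10^{-5}
\end{aligned}
```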

Complete Example

Linear Regression: view the complete linear regression example.

Logistic Regression: view a similar logistic regression example.
