This example shows how to use mx.grad() to train a model using stochastic gradient descent (SGD).
Setup
First, import the core package and set up the problem metadata.

Generate Synthetic Data
We’ll create a synthetic dataset by sampling a ground-truth weight vector, drawing random input examples, and computing noisy labels from them.

Define Loss and Gradient
We’ll use SGD to find the optimal weights. Define the squared loss and get the gradient function of the loss with respect to the parameters. The mx.grad() function automatically computes the gradient of the loss function with respect to its input parameter w.
Training Loop
Initialize the parameters randomly, then repeatedly update them for num_iters iterations:
The mx.eval(w) call forces evaluation of the computation graph, which is necessary because MLX uses lazy evaluation.
Evaluate Results
Finally, compute the loss of the learned parameters and verify that they are close to the ground-truth parameters.

Expected Output
Complete Example
Linear Regression
View the complete linear regression example
Logistic Regression
View a similar logistic regression example