The mlx.optimizers module provides optimizers for training neural networks. All optimizers work with both mlx.nn modules and pure mlx.core functions.
Quick Start
Here’s a typical training loop with an optimizer:
Base Optimizer

Base class for all optimizers. Allows implementing optimizers on a per-parameter basis and applying them to parameter trees.

Key Methods:
- update(model, gradients): Apply gradients to model parameters
- init(parameters): Initialize optimizer state
- apply_gradients(gradients, parameters): Apply gradients and return updated parameters
Optimizer Methods
Optimizer.update(model, gradients)

Apply the gradients to the parameters of the model and update the model.

Parameters:
- model (nn.Module): An MLX module to be updated
- gradients (dict): Python tree of gradients, typically from nn.value_and_grad
Optimizer.init(parameters)

Initialize the optimizer’s state. This is optional: the optimizer will initialize itself on the first call to update if init is not called explicitly.

Parameters:
- parameters (dict): Python tree of parameters
Optimizer.state

The optimizer’s state dictionary. Contains the step count, the learning rate, and optimizer-specific state (e.g., momentum).
Common Optimizers
Stochastic Gradient Descent optimizer.

Updates:
v_{t+1} = μ v_t + (1 - τ) g_t
w_{t+1} = w_t - λ v_{t+1}

Parameters:
- learning_rate (float or callable): The learning rate λ
- momentum (float): The momentum strength μ. Default: 0
- weight_decay (float): The weight decay (L2 penalty). Default: 0
- dampening (float): Dampening for momentum τ. Default: 0
- nesterov (bool): Enables Nesterov momentum. Default: False
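The update rule above can be traced in plain Python (scalar weight, no mlx; values are purely illustrative):

```python
# Sketch of the SGD-with-momentum update rule on a scalar weight.
def sgd_step(w, v, g, lr=0.1, momentum=0.9, dampening=0.0):
    v = momentum * v + (1 - dampening) * g   # v_{t+1} = mu*v_t + (1-tau)*g_t
    w = w - lr * v                           # w_{t+1} = w_t - lambda*v_{t+1}
    return w, v

w, v = 1.0, 0.0
w, v = sgd_step(w, v, g=0.5)   # first step: v = 0.5, w = 1.0 - 0.1*0.5
```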
Adam optimizer.

Parameters:
- learning_rate (float or callable): The learning rate λ
- betas (Tuple[float, float]): Coefficients (β₁, β₂) for the running averages. Default: (0.9, 0.999)
- eps (float): Term ε added to the denominator for numerical stability. Default: 1e-8
- bias_correction (bool): If True, apply bias correction. Default: False
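As a rough sketch of what one Adam step computes (plain Python, scalar weight, no bias correction, matching the default above; this is for intuition, not mlx's implementation):

```python
import math

# One Adam step on a scalar weight: two running moments, then a
# scaled update. Defaults mirror the parameter list above.
def adam_step(w, m, v, g, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * g          # first-moment running average
    v = b2 * v + (1 - b2) * g * g      # second-moment running average
    w = w - lr * m / (math.sqrt(v) + eps)
    return w, m, v

w, m, v = adam_step(0.0, 0.0, 0.0, g=1.0)  # m ≈ 0.1, v ≈ 0.001
```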
AdamW optimizer with decoupled weight decay.

Parameters:
- learning_rate (float or callable): The learning rate α
- betas (Tuple[float, float]): Coefficients (β₁, β₂) for the running averages. Default: (0.9, 0.999)
- eps (float): Term ε added to the denominator for numerical stability. Default: 1e-8
- weight_decay (float): The weight decay λ. Default: 0.01
- bias_correction (bool): If True, apply bias correction. Default: False
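"Decoupled" means the decay term acts directly on the weight rather than being folded into the gradient. A plain-Python sketch of that distinction (scalar weight; illustrative, not mlx's code):

```python
import math

# AdamW step: identical moments to Adam, but the weight decay
# lambda*w is added to the update outside the adaptive scaling.
def adamw_step(w, m, v, g, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    w = w - lr * (m / (math.sqrt(v) + eps) + wd * w)  # decoupled decay
    return w, m, v

# With a zero gradient, the weight still shrinks toward zero:
w, m, v = adamw_step(1.0, 0.0, 0.0, g=0.0)
```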
Adamax optimizer, a variant of Adam based on the infinity norm.

Parameters:
- learning_rate (float or callable): The learning rate λ
- betas (Tuple[float, float]): Coefficients (β₁, β₂). Default: (0.9, 0.999)
- eps (float): Term ε added to the denominator. Default: 1e-8
Lion optimizer.

It is recommended to use a learning rate 3-10x smaller than for AdamW and a weight decay 3-10x larger.

Parameters:
- learning_rate (float or callable): The learning rate η
- betas (Tuple[float, float]): Coefficients (β₁, β₂). Default: (0.9, 0.99)
- weight_decay (float): The weight decay λ. Default: 0.0
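Lion updates weights by the sign of an interpolated momentum, which is why its effective step is larger than Adam's for the same learning rate. A plain-Python sketch following the published Lion update rule (scalar weight; not taken from the mlx source):

```python
def sign(x):
    # Sign as -1, 0, or 1
    return (x > 0) - (x < 0)

# One Lion step: sign-based update with decoupled weight decay,
# then a separate momentum update with beta_2.
def lion_step(w, m, g, lr=1e-4, b1=0.9, b2=0.99, wd=0.0):
    c = b1 * m + (1 - b1) * g        # interpolation used for the sign
    w = w - lr * (sign(c) + wd * w)  # update magnitude is always lr
    m = b2 * m + (1 - b2) * g        # momentum update
    return w, m

w, m = lion_step(0.0, 0.0, g=1.0)   # w moves by exactly -lr
```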
Adagrad optimizer.

Parameters:
- learning_rate (float or callable): The learning rate λ
- eps (float): Term ε added to the denominator for numerical stability. Default: 1e-8
AdaDelta optimizer with a learning rate.

Parameters:
- learning_rate (float or callable): The learning rate λ
- rho (float): Coefficient ρ for computing the running average of squared gradients. Default: 0.9
- eps (float): Term ε added to the denominator for numerical stability. Default: 1e-6
RMSprop optimizer.

Parameters:
- learning_rate (float or callable): The learning rate λ
- alpha (float): The smoothing constant α. Default: 0.99
- eps (float): Term ε added to the denominator for numerical stability. Default: 1e-8
Adafactor optimizer with adaptive learning rates and a sublinear memory cost.

Parameters:
- learning_rate (float or callable): The learning rate. Default: None
- eps (tuple): (ε₁, ε₂) for numerical stability and parameter scaling. Default: (1e-30, 1e-3)
- clip_threshold (float): Clips the unscaled update at this threshold. Default: 1.0
- decay_rate (float): Coefficient for the running average of the squared gradient. Default: -0.8
- beta_1 (float): If set, use the first moment. Default: None
- weight_decay (float): The weight decay λ. Default: 0.0
- scale_parameter (bool): If True, scale the learning rate by the RMS of the parameters. Default: True
- relative_step (bool): If True, use a relative step size. Default: True
- warmup_init (bool): If True, calculate the step size based on the current step. Default: False
Muon (MomentUm Orthogonalized by Newton-Schulz) optimizer.

Note: Muon may be sub-optimal for embedding layers, final fully connected layers, or 0D/1D parameters. Use a different optimizer (e.g., AdamW) for those.

Parameters:
- learning_rate (float or callable): The learning rate
- momentum (float): The momentum strength. Default: 0.95
- weight_decay (float): The weight decay (L2 penalty). Default: 0.01
- nesterov (bool): Enables Nesterov momentum. Default: True
- ns_steps (int): Number of Newton-Schulz iteration steps. Default: 5
Multi-Optimizer
Wraps multiple optimizers with weight predicates so that different optimizers can be used for different parameters.

Parameters:
- optimizers (list[Optimizer]): List of optimizers to delegate to
- filters (list[Callable]): List of predicates, one fewer than the number of optimizers. The last optimizer acts as the fallback.
Learning Rate Schedulers
Learning rate schedulers can be passed directly to optimizers as the learning_rate argument.

Make an exponential decay scheduler.

Parameters:
- init (float): Initial value
- decay_rate (float): Multiplicative factor to decay by
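A scheduler is just a callable from the step count to a value. A plain-Python sketch of exponential decay (illustrative; mlx's scheduler returns mlx arrays, but the schedule it computes has this shape):

```python
# Exponential decay: value at step t is init * decay_rate ** t.
def exponential_decay(init, decay_rate):
    def schedule(step):
        return init * decay_rate ** step
    return schedule

lr = exponential_decay(0.1, 0.9)
lr(0)   # 0.1
lr(2)   # ≈ 0.081
```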
Make a step decay scheduler.

Parameters:
- init (float): Initial value
- decay_rate (float): Multiplicative factor to decay by
- step_size (int): Decay every step_size steps
Make a cosine decay scheduler.

Parameters:
- init (float): Initial value
- decay_steps (int): Number of steps to decay over
- end (float): Final value to decay to. Default: 0.0
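A plain-Python sketch of the assumed cosine-decay shape: decay from init to end over decay_steps, then hold at end (the exact clamping behavior in mlx may differ):

```python
import math

# Cosine decay from init to end over decay_steps steps.
def cosine_decay(init, decay_steps, end=0.0):
    def schedule(step):
        t = min(step, decay_steps) / decay_steps   # progress in [0, 1]
        return end + (init - end) * 0.5 * (1 + math.cos(math.pi * t))
    return schedule

lr = cosine_decay(0.1, 100)
lr(0)     # 0.1
lr(50)    # ≈ 0.05 (halfway point)
lr(200)   # 0.0 (held at end after decay_steps)
```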
Make a linear scheduler.

Parameters:
- init (float): Initial value
- end (float): Final value
- steps (int): Number of steps to apply the schedule over
Join multiple schedules to create a new schedule.

Parameters:
- schedules (list[Callable]): List of schedules
- boundaries (list[int]): Boundaries indicating when to transition between schedules
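A common use is a warmup phase joined to a main schedule. A simplified two-schedule sketch in plain Python (assumed semantics: the second schedule sees steps counted from the boundary; check mlx's behavior before relying on this):

```python
# Join two schedules at a single boundary.
def join_schedules(schedules, boundaries):
    def schedule(step):
        if step < boundaries[0]:
            return schedules[0](step)
        return schedules[1](step - boundaries[0])
    return schedule

warmup = lambda step: 0.01 * (step + 1) / 10   # linear warmup to 0.01
constant = lambda step: 0.01                   # then hold
lr = join_schedules([warmup, constant], [10])
lr(0)    # 0.001
lr(100)  # 0.01
```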
Gradient Clipping
Clips the global norm of the gradients. Ensures that the global norm of the gradients does not exceed max_norm, scaling the gradients down proportionally if needed.

Parameters:
- grads (dict): Dictionary containing gradient arrays
- max_norm (float): Maximum allowed global norm of the gradients

Returns:
- (dict, float): The clipped gradients and the original gradient norm
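The scaling logic can be sketched in plain Python over a dict of list-valued "gradients" (mlx operates on arrays and nested trees; this is just the arithmetic):

```python
import math

# Global-norm clipping: compute the norm over all entries, and if it
# exceeds max_norm, scale every gradient by max_norm / norm.
def clip_global_norm(grads, max_norm):
    total = math.sqrt(sum(g * g for gs in grads.values() for g in gs))
    if total > max_norm:
        scale = max_norm / total
        grads = {k: [g * scale for g in gs] for k, gs in grads.items()}
    return grads, total

grads = {"w": [3.0], "b": [4.0]}            # global norm = 5.0
clipped, norm = clip_global_norm(grads, 1.0)
```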
Saving and Loading
To serialize an optimizer, save its state. To load an optimizer, load and set the saved state. Note that not all hyperparameters are part of the saved state: the learning rate is, but betas and eps are not. As a rule of thumb, if a parameter can be scheduled, it will be included in the optimizer state.