Implementing the Model
We’ll use the neural network building blocks defined in the mlx.nn module to concisely define the model architecture.
Attention Layer
We’ll start with the Llama attention layer, which uses RoPE positional encoding. Our implementation includes an optional key/value cache for efficient inference. We use mlx.nn.Linear for all projections and mlx.nn.RoPE for positional encoding:
Encoder Layer
The encoder layer uses RMS normalization and SwiGLU activation. We use the mlx.nn.RMSNorm layer that’s already provided:
Full Model
To implement any Llama model, we simply combine LlamaEncoderLayer instances with an mlx.nn.Embedding to embed the input tokens:
We use a simple list to hold the encoder layers, but model.parameters() will still consider these layers.
Generation
The __call__ method above is suitable for training but not inference. We need to add a generation method that uses the cache and performs autoregressive sampling:
Using the Model
We now have everything needed to create a Llama model and sample tokens from it.
Converting Weights
To use actual Llama weights, you need to convert the PyTorch weights to MLX format with a script that maps each PyTorch parameter name to its MLX counterpart.
Loading Weights and Benchmarking
Load the converted weights using mlx.utils.tree_unflatten:
The tree_unflatten function transforms flat keys like layers.2.attention.query_proj.weight into nested dictionaries that can be used to update the model.