## Overview

The mistral-mlx crate provides high-performance inference for Mistral language models on Apple Silicon. The crate is optimized for pre-quantized 4-bit models, with async pipelining for maximum throughput.
## Key features

- **Optimized for 4-bit quantization** - excellent quality with ~4x memory reduction
- **Async pipelining** - maximizes throughput on Apple Silicon
- **Grouped Query Attention (GQA)** - efficient multi-head attention
- **~74 tok/s on M-series** - high performance on Mistral-7B-4bit
## Installation

Add to your `Cargo.toml`:

```toml
[dependencies]
mistral-mlx = "0.1"
```
## Core functions

### load_model

Loads a Mistral model from a directory containing weights and configuration.

```rust
pub fn load_model(model_dir: impl AsRef<Path>) -> Result<Model, Error>
```

**Parameters:** path to the model directory containing:

- `config.json` - model configuration
- `model.safetensors.index.json` - weight file index
- `model-*.safetensors` - model weights

**Returns:** a loaded `Model` ready for inference.
### load_tokenizer

Loads the tokenizer from the model directory.

```rust
pub fn load_tokenizer(model_dir: impl AsRef<Path>) -> Result<Tokenizer, Error>
```

**Parameters:** path to the model directory containing `tokenizer.json`.

**Returns:** a HuggingFace `Tokenizer` instance.
### get_model_args

Parses model configuration from `config.json`.

```rust
pub fn get_model_args(model_dir: impl AsRef<Path>) -> Result<ModelArgs, Error>
```

**Parameters:** path to the directory containing `config.json`.

**Returns:** parsed `ModelArgs` with model hyperparameters.
### init_cache

Initializes an empty KV cache for the specified number of layers.

```rust
pub fn init_cache<C: Default>(num_layers: usize) -> Vec<C>
```

**Parameters:** number of transformer layers (from `ModelArgs::num_hidden_layers`).

**Returns:** a vector of default-initialized cache entries.
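The behavior can be sketched in a few lines (an illustrative reimplementation, not the crate's source). Note that the `Generate` examples below work with `Vec<Option<KVCache>>`, and `Option<T>` is `Default` (`None`), so each per-layer slot starts empty and is filled on first use:

```rust
// Sketch of init_cache: one default-initialized cache slot per layer.
fn init_cache<C: Default>(num_layers: usize) -> Vec<C> {
    (0..num_layers).map(|_| C::default()).collect()
}

fn main() {
    // u32 stands in for a real cache type here.
    let cache: Vec<Option<u32>> = init_cache(32);
    assert_eq!(cache.len(), 32);
    assert!(cache.iter().all(|c| c.is_none())); // every slot starts empty
}
```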
## Types

### Model

The main model struct for Mistral inference.

```rust
pub struct Model {
    pub args: ModelArgs,
    pub model: MistralModel,
    pub lm_head: MaybeQuantized<nn::Linear>,
}
```

| Field | Type | Description |
|---|---|---|
| `args` | `ModelArgs` | Model configuration and hyperparameters |
| `model` | `MistralModel` | The core Mistral transformer model |
| `lm_head` | `MaybeQuantized<nn::Linear>` | Language modeling head projection |
### ModelArgs

Mistral model configuration.

```rust
pub struct ModelArgs {
    pub hidden_size: i32,
    pub num_hidden_layers: i32,
    pub head_dim: i32,
    pub intermediate_size: i32,
    pub num_attention_heads: i32,
    pub num_key_value_heads: i32,
    pub rms_norm_eps: f32,
    pub vocab_size: i32,
    pub rope_theta: f32,
    pub quantization: Option<QuantizationConfig>,
    pub tie_word_embeddings: bool,
}
```

Notable fields:

- `num_key_value_heads` - number of KV heads for Grouped Query Attention (GQA)
- `head_dim` - dimension of each attention head
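These hyperparameters determine the KV cache footprint. A back-of-envelope sketch, assuming Mistral-7B's published values (32 layers, 8 KV heads, `head_dim` = 128) and an fp16 cache:

```rust
// Bytes of KV cache per generated token: K and V tensors (hence the factor
// of 2), one head_dim vector per KV head per layer.
fn kv_cache_bytes_per_token(
    num_layers: u64,
    num_kv_heads: u64,
    head_dim: u64,
    elem_bytes: u64,
) -> u64 {
    2 * num_layers * num_kv_heads * head_dim * elem_bytes
}

fn main() {
    // Assumed Mistral-7B values: 32 layers, 8 KV heads, head_dim 128, fp16.
    let per_token = kv_cache_bytes_per_token(32, 8, 128, 2);
    assert_eq!(per_token, 131_072); // 128 KiB per token
    let ctx_4096 = per_token * 4096;
    println!("4096-token cache: {} MiB", ctx_4096 / (1024 * 1024)); // 512 MiB
    // With 32 KV heads (no GQA) the cache would be 4x larger.
}
```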
### Generate

Iterator for autoregressive text generation.

```rust
pub struct Generate<'a, C: KeyValueCache> {
    model: &'a mut Model,
    cache: &'a mut Vec<Option<C>>,
    temp: f32,
    state: GenerateState<'a>,
    prefetched: Option<Array>,
    token_count: usize,
}
```

#### Constructor

```rust
pub fn new(
    model: &'a mut Model,
    cache: &'a mut Vec<Option<C>>,
    temp: f32,
    prompt_token: &'a Array,
) -> Self
```

| Parameter | Type | Description |
|---|---|---|
| `model` | `&'a mut Model` | Mutable reference to the loaded model |
| `cache` | `&'a mut Vec<Option<C>>` | KV cache for attention (use `init_cache` or an empty `Vec`) |
| `temp` | `f32` | Sampling temperature (0.0 = greedy, higher = more random) |
| `prompt_token` | `&'a Array` | Encoded prompt tokens as an MLX array with shape `[1, seq_len]` |
## Example usage

### Basic generation

```rust
use mistral_mlx::{load_model, load_tokenizer, Generate, KVCache};
use mlx_rs::ops::indexing::NewAxis;

let model_dir = "models/Mistral-7B-Instruct-v0.2";

// Load model and tokenizer
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

// Encode prompt
let encoding = tokenizer.encode("Once upon a time, ", true)?;
let prompt = mlx_rs::Array::from(encoding.get_ids()).index(NewAxis);

// Initialize cache
let mut cache = Vec::new();

// Generate tokens
let generator = Generate::<KVCache>::new(&mut model, &mut cache, 0.7, &prompt);
for token in generator.take(100) {
    let token = token?;
    let text = tokenizer.decode(&[token.item::<u32>()], true)?;
    print!("{}", text);
}
```
### Instruction prompting

```rust
use mistral_mlx::{load_model, load_tokenizer, Generate, KVCache};
use mlx_rs::ops::indexing::NewAxis;

let model_dir = "models/Mistral-7B-Instruct-v0.2";
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

// Format instruction prompt
let instruction = "Explain the theory of relativity in simple terms.";
let prompt_text = format!("[INST] {} [/INST]", instruction);
let encoding = tokenizer.encode(&prompt_text, true)?;
let prompt = mlx_rs::Array::from(encoding.get_ids()).index(NewAxis);

let mut cache = Vec::new();
let generator = Generate::<KVCache>::new(&mut model, &mut cache, 0.6, &prompt);
for token in generator.take(200) {
    let token = token?;
    let id = token.item::<u32>();
    // Stop on EOS token
    if id == 2 {
        break;
    }
    let text = tokenizer.decode(&[id], true)?;
    print!("{}", text);
}
```
### Greedy decoding

```rust
use mistral_mlx::{load_model, load_tokenizer, Generate, KVCache};
use mlx_rs::ops::indexing::NewAxis;

let model_dir = "models/Mistral-7B-Instruct-v0.2";
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

let encoding = tokenizer.encode("The capital of France is", true)?;
let prompt = mlx_rs::Array::from(encoding.get_ids()).index(NewAxis);
let mut cache = Vec::new();

// Temperature 0.0 = greedy (always pick the most likely token)
let generator = Generate::<KVCache>::new(&mut model, &mut cache, 0.0, &prompt);
for token in generator.take(10) {
    let token = token?;
    let text = tokenizer.decode(&[token.item::<u32>()], true)?;
    print!("{}", text);
}
```
## Architecture components

### Attention

Grouped Query Attention with RoPE.

```rust
pub struct Attention {
    pub n_heads: i32,
    pub n_kv_heads: i32,
    pub head_dim: i32,
    pub scale: f32,
    pub q_proj: MaybeQuantized<nn::Linear>,
    pub k_proj: MaybeQuantized<nn::Linear>,
    pub v_proj: MaybeQuantized<nn::Linear>,
    pub o_proj: MaybeQuantized<nn::Linear>,
    pub rope: nn::Rope,
}
```

Notable fields:

- `n_kv_heads` - number of key-value heads (typically 8 for Mistral-7B)
- `scale` - attention scaling factor: `1.0 / sqrt(head_dim)`
### FeedForward

MLP with SwiGLU activation.

```rust
pub struct FeedForward {
    pub gate_proj: MaybeQuantized<nn::Linear>,
    pub down_proj: MaybeQuantized<nn::Linear>,
    pub up_proj: MaybeQuantized<nn::Linear>,
}
```
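A scalar sketch of the SwiGLU combination these projections implement, with the linear layers elided: the value fed to `down_proj` is `silu(gate_proj(x)) * up_proj(x)`, elementwise.

```rust
// SiLU (a.k.a. swish): x * sigmoid(x).
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

// SwiGLU gating, taking the gate-path and up-path activations directly
// (the projections themselves are omitted from this sketch).
fn swiglu(gate_x: f32, up_x: f32) -> f32 {
    silu(gate_x) * up_x
}

fn main() {
    // silu(0) = 0, so a zero gate fully suppresses the up path.
    assert_eq!(swiglu(0.0, 5.0), 0.0);
    // silu(x) -> x for large x, so the gate passes large activations through.
    assert!((swiglu(20.0, 1.0) - 20.0).abs() < 1e-4);
}
```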
### TransformerBlock

Single transformer layer.

```rust
pub struct TransformerBlock {
    pub self_attn: Attention,
    pub mlp: FeedForward,
    pub input_layernorm: nn::RmsNorm,
    pub post_attention_layernorm: nn::RmsNorm,
}
```
### Grouped Query Attention (GQA)

Mistral uses GQA to reduce memory bandwidth:

```rust
// Q has more heads than K/V
let n_heads = 32; // Query heads
let n_kv_heads = 8; // Key/Value heads
let repeats = n_heads / n_kv_heads; // 4

// Repeat K/V to match Q
let k = k.repeat(&[1, 1, repeats, 1])?;
let v = v.repeat(&[1, 1, repeats, 1])?;

// Now all have 32 heads for attention
let scores = q.matmul(&k.transpose(-2, -1)?)?;
```
## Performance notes

- 4-bit quantization reduces memory by ~4x with minimal quality loss
- GQA reduces the KV cache size by 4x (8 KV heads vs. 32 Q heads)
- Async pipelining overlaps compute and memory operations
- Metal acceleration achieves ~74 tok/s on M-series Macs
## Benchmarks

| Model | Precision | M1 Max | M2 Ultra |
|---|---|---|---|
| Mistral-7B | FP16 | ~32 tok/s | ~45 tok/s |
| Mistral-7B | 4-bit | ~74 tok/s | ~110 tok/s |
## Memory requirements

| Model | Precision | Memory |
|---|---|---|
| Mistral-7B | FP16 | ~14 GB |
| Mistral-7B | 4-bit | ~4 GB |
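These figures follow from back-of-envelope arithmetic, assuming ~7.24B parameters for Mistral-7B and ~4.5 effective bits per parameter for 4-bit storage (group scales and biases add overhead beyond 4.0 bits):

```rust
// Weight memory in GiB for a given parameter count and storage precision.
fn weight_gib(params: f64, bits_per_param: f64) -> f64 {
    params * bits_per_param / 8.0 / (1024.0 * 1024.0 * 1024.0)
}

fn main() {
    let params = 7.24e9; // assumed Mistral-7B parameter count
    let fp16 = weight_gib(params, 16.0);
    let q4 = weight_gib(params, 4.5); // ~4.5 bits/param with quantization metadata
    println!("fp16: {:.1} GiB, 4-bit: {:.1} GiB", fp16, q4);
    assert!(fp16 > 13.0 && fp16 < 14.0); // consistent with "~14 GB"
    assert!(q4 > 3.5 && q4 < 4.5); // consistent with "~4 GB"
}
```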
## See also