## Overview
The mixtral-mlx crate provides high-performance inference for Mixtral MoE (Mixture of Experts) models on Apple Silicon. Mixtral uses sparse expert routing with 8 experts and top-2 selection per token, enabling larger model capacity with efficient computation.
## Key features
- 8 experts with top-2 routing - Each token activates only 2 of 8 expert networks
- Custom fused SwiGLU kernel - 10-12x faster expert computation via Metal
- Optimized gather_qmm - Efficient expert dispatch and gathering
- 4-bit quantization - Reduced memory for running large MoE models
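The efficiency claim can be made concrete with a back-of-envelope parameter count. The dimensions below (hidden 4096, intermediate 14336, 32 layers, 8 experts, top-2, 8 KV heads) are the publicly documented Mixtral-8x7B configuration; the shared-parameter total is approximate since it omits norm and router weights:

```rust
/// gate_proj + up_proj + down_proj, each hidden x intermediate.
fn expert_params(hidden: u64, intermediate: u64) -> u64 {
    3 * hidden * intermediate
}

/// Parameters shared by every token: attention (GQA), embeddings, lm_head.
/// Norm and router weights are omitted (they are comparatively tiny).
fn shared_params(hidden: u64, kv_dim: u64, layers: u64, vocab: u64) -> u64 {
    let attn_per_layer = 2 * hidden * hidden + 2 * hidden * kv_dim;
    layers * attn_per_layer + 2 * vocab * hidden
}

fn main() {
    let (hidden, intermediate, layers) = (4096, 14336, 32);
    let shared = shared_params(hidden, 1024, layers, 32_000);
    let per_expert = expert_params(hidden, intermediate);

    let total = shared + layers * 8 * per_expert;  // all 8 experts stored
    let active = shared + layers * 2 * per_expert; // only top-2 run per token
    println!(
        "total ≈ {:.1}B params, active ≈ {:.1}B per token",
        total as f64 / 1e9,
        active as f64 / 1e9
    );
}
```

This lands at roughly 46.7B total parameters but only about 12.9B active per token, which is why Mixtral decodes far faster than a dense model of the same size.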
## Installation

Add to your `Cargo.toml`:

```toml
[dependencies]
mixtral-mlx = "0.1"
```
## Core functions

### load_model

Loads a Mixtral model from a directory containing weights and configuration.

```rust
pub fn load_model(model_dir: impl AsRef<Path>) -> Result<Model, Error>
```

Parameters:
- `model_dir` - Path to the model directory containing:
  - `config.json` - Model configuration
  - `model.safetensors.index.json` - Weight file index
  - `model-*.safetensors` - Model weights

Returns a loaded `Model` ready for inference.
### load_tokenizer

Loads the tokenizer from the model directory.

```rust
pub fn load_tokenizer(model_dir: impl AsRef<Path>) -> Result<Tokenizer, Error>
```

Parameters:
- `model_dir` - Path to the model directory containing `tokenizer.json`

Returns a HuggingFace `Tokenizer` instance.
### get_model_args

Parses model configuration from `config.json`.

```rust
pub fn get_model_args(model_dir: impl AsRef<Path>) -> Result<ModelArgs, Error>
```

Parameters:
- `model_dir` - Path to the directory containing `config.json`

Returns parsed `ModelArgs` with model hyperparameters.
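For reference, a `config.json` for Mixtral-8x7B looks roughly like the following. The values shown are the publicly documented Mixtral-8x7B hyperparameters; treat this fragment as illustrative rather than canonical:

```json
{
  "model_type": "mixtral",
  "vocab_size": 32000,
  "hidden_size": 4096,
  "intermediate_size": 14336,
  "num_hidden_layers": 32,
  "num_attention_heads": 32,
  "num_key_value_heads": 8,
  "num_experts_per_tok": 2,
  "num_local_experts": 8,
  "rms_norm_eps": 1e-05,
  "rope_theta": 1000000.0
}
```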
## Types

### Model

The main model struct for Mixtral inference.

```rust
pub struct Model {
    pub args: ModelArgs,
    pub model: MixtralModel,
    pub lm_head: MaybeQuantized<nn::Linear>,
}
```

Fields:
- `args` - Model configuration and hyperparameters
- `model` - The core Mixtral transformer model
- `lm_head` - Language modeling head projection
### ModelArgs

Mixtral model configuration.

```rust
pub struct ModelArgs {
    pub model_type: String,
    pub vocab_size: i32,
    pub hidden_size: i32,
    pub intermediate_size: i32,
    pub num_hidden_layers: i32,
    pub num_attention_heads: i32,
    pub num_experts_per_tok: i32, // Typically 2
    pub num_key_value_heads: Option<i32>,
    pub num_local_experts: i32, // Typically 8
    pub rms_norm_eps: f32,
    pub rope_theta: f32,
    pub rope_traditional: bool,
    pub tie_word_embeddings: bool,
    pub quantization: Option<QuantizationConfig>,
}
```

Notable fields:
- `num_experts_per_tok` - Number of experts to activate per token (top-k routing)
- `num_local_experts` - Total number of expert networks in each MoE layer
### Generate

Iterator for autoregressive text generation.

```rust
pub struct Generate<'a, C: KeyValueCache> {
    model: &'a mut Model,
    cache: &'a mut Vec<Option<C>>,
    temp: f32,
    state: GenerateState<'a>,
    prefetched: Option<Array>,
    token_count: usize,
}
```

#### Constructor

```rust
pub fn new(
    model: &'a mut Model,
    cache: &'a mut Vec<Option<C>>,
    temp: f32,
    prompt_token: &'a Array,
) -> Self
```

Parameters:
- `model` - Mutable reference to the loaded model
- `cache` - KV cache for attention (initially empty)
- `temp` - Sampling temperature (0.0 = greedy, higher = more random)
- `prompt_token` - Encoded prompt tokens as an MLX array with shape `[1, seq_len]`
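The effect of `temp` can be sketched with a toy softmax in plain Rust (std only; this is an illustration of the idea, not the crate's actual sampler). Dividing logits by a temperature below 1.0 sharpens the distribution toward the argmax, while higher values flatten it:

```rust
/// Softmax over logits scaled by a temperature.
fn softmax_with_temp(logits: &[f32], temp: f32) -> Vec<f32> {
    let scaled: Vec<f32> = logits.iter().map(|&l| l / temp).collect();
    // Subtract the max for numerical stability before exponentiating.
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|&s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    let logits = [2.0, 1.0, 0.5];
    for temp in [0.3, 1.0, 2.0] {
        // Lower temp concentrates mass on the highest logit (greedy in the limit).
        println!("temp {temp}: {:?}", softmax_with_temp(&logits, temp));
    }
}
```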
## Example usage

### Basic generation

```rust
use mixtral_mlx::{load_model, load_tokenizer, Generate, KVCache};
use mlx_rs::ops::indexing::NewAxis;

let model_dir = "models/Mixtral-8x7B-Instruct-v0.1";

// Load model and tokenizer
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

// Encode prompt
let encoding = tokenizer.encode("Explain quantum computing in simple terms: ", true)?;
let prompt = mlx_rs::Array::from(encoding.get_ids()).index(NewAxis);

// Initialize cache
let mut cache = Vec::new();

// Generate tokens
let generator = Generate::<KVCache>::new(&mut model, &mut cache, 0.6, &prompt);
for token in generator.take(200) {
    let token = token?;
    let text = tokenizer.decode(&[token.item::<u32>()], true)?;
    print!("{}", text);
}
```
### Instruction-formatted prompts

```rust
use mixtral_mlx::{load_model, load_tokenizer, Generate, KVCache};
use mlx_rs::ops::indexing::NewAxis;

let model_dir = "models/Mixtral-8x7B-Instruct-v0.1";
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

// Format instruction prompt
let instruction = "Write a Python function to calculate Fibonacci numbers.";
let prompt_text = format!("[INST] {} [/INST]", instruction);
let encoding = tokenizer.encode(&prompt_text, true)?;
let prompt = mlx_rs::Array::from(encoding.get_ids()).index(NewAxis);

let mut cache = Vec::new();
let generator = Generate::<KVCache>::new(&mut model, &mut cache, 0.5, &prompt);
for token in generator.take(300) {
    let token = token?;
    let id = token.item::<u32>();
    // Stop on EOS (token id 2 in Mixtral's tokenizer)
    if id == 2 {
        break;
    }
    let text = tokenizer.decode(&[id], true)?;
    print!("{}", text);
}
```
### Loading quantized models

```rust
use mixtral_mlx::{load_model, load_tokenizer};

// A 4-bit quantized model significantly reduces memory
let model_dir = "models/Mixtral-8x7B-Instruct-v0.1-4bit";
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?; // Automatically detects quantization

// Use normally - quantization is transparent
// Memory usage: ~90 GB (FP16) -> ~24 GB with 4-bit
```
## Architecture components

### MixtralSparseMoeBlock

Sparse Mixture of Experts layer with top-k routing.

```rust
pub struct MixtralSparseMoeBlock {
    pub num_experts: i32,
    pub top_k: i32,
    pub gate: nn::Linear,
    pub experts: Vec<SwitchGLU>,
}
```

Fields:
- `num_experts` - Total number of expert networks (typically 8)
- `top_k` - Number of experts to select per token (typically 2)
- `gate` - Router network that scores all experts for each token
- `experts` - Vector of expert feed-forward networks
### SwitchGLU

Expert network using the SwiGLU activation.

```rust
pub struct SwitchGLU {
    pub gate_proj: MaybeQuantized<nn::Linear>,
    pub up_proj: MaybeQuantized<nn::Linear>,
    pub down_proj: MaybeQuantized<nn::Linear>,
}
```
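The three projections combine as `down_proj(silu(gate_proj(x)) * up_proj(x))`. The element-wise core of that computation can be sketched in plain Rust (std only; the projections are reduced to scalar values here purely for illustration):

```rust
/// SiLU (swish) activation: x * sigmoid(x).
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// Element-wise core of SwiGLU: silu(gate) * up.
/// In the real layer, `gate` and `up` are linear projections of the same
/// input, and the product is then passed through `down_proj`.
fn swiglu_elem(gate: f32, up: f32) -> f32 {
    silu(gate) * up
}

fn main() {
    println!("silu(0) = {}", silu(0.0));                  // 0
    println!("silu(1) = {:.4}", silu(1.0));               // ~0.7311
    println!("swiglu  = {:.4}", swiglu_elem(1.0, 2.0));   // ~1.4621
}
```

The gating product is what the crate's fused Metal kernel computes in one pass per expert.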
### Attention

Standard multi-head attention with GQA and RoPE.

```rust
pub struct Attention {
    pub n_heads: i32,
    pub n_kv_heads: i32,
    pub scale: f32,
    pub q_proj: MaybeQuantized<nn::Linear>,
    pub k_proj: MaybeQuantized<nn::Linear>,
    pub v_proj: MaybeQuantized<nn::Linear>,
    pub o_proj: MaybeQuantized<nn::Linear>,
    pub rope: nn::Rope,
}
```
### DecoderLayer

Transformer layer with MoE instead of a standard MLP.

```rust
pub struct DecoderLayer {
    pub self_attn: Attention,
    pub block_sparse_moe: MixtralSparseMoeBlock,
    pub input_layernorm: nn::RmsNorm,
    pub post_attention_layernorm: nn::RmsNorm,
}
```
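The layer follows the usual pre-norm residual pattern: `x = x + self_attn(input_layernorm(x))`, then `x = x + block_sparse_moe(post_attention_layernorm(x))`. RMSNorm itself is simple enough to sketch in plain Rust (std only; a unit scale weight is assumed for brevity, whereas `nn::RmsNorm` learns a per-dimension scale):

```rust
/// RMSNorm with unit scale: x_i / sqrt(mean(x^2) + eps).
fn rms_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq: f32 = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().map(|v| v * inv_rms).collect()
}

fn main() {
    let x = [3.0, 4.0];
    // mean(x^2) = 12.5, rms ≈ 3.536 -> [~0.8485, ~1.1314]
    println!("{:?}", rms_norm(&x, 1e-5));
}
```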
## Mixture of Experts routing

Mixtral uses top-k sparse routing:

```rust
// For each token, compute router logits for all experts
let router_logits = self.gate.forward(x)?; // [batch, seq, num_experts]

// Select top-k experts
let (top_k_logits, top_k_indices) = router_logits.top_k(self.top_k)?;

// Softmax over selected experts
let routing_weights = mlx_rs::ops::softmax(&top_k_logits, -1)?;

// Dispatch to selected experts and combine outputs
let expert_outputs = dispatch_to_experts(x, &top_k_indices)?;
let output = combine_expert_outputs(&expert_outputs, &routing_weights)?;
```
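The same routing math for a single token can be sketched in plain Rust with std only (illustrative; the crate performs this on MLX arrays for whole batches at once):

```rust
/// Top-k expert indices and their softmax weights for one token's router logits.
fn top_k_route(logits: &[f32], k: usize) -> (Vec<usize>, Vec<f32>) {
    // Sort expert indices by descending logit and keep the top k.
    let mut idx: Vec<usize> = (0..logits.len()).collect();
    idx.sort_by(|&a, &b| logits[b].partial_cmp(&logits[a]).unwrap());
    idx.truncate(k);

    // Softmax over just the selected logits (numerically stabilized).
    let max = logits[idx[0]];
    let exps: Vec<f32> = idx.iter().map(|&i| (logits[i] - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    (idx, exps.iter().map(|&e| e / sum).collect())
}

fn main() {
    // 8 router logits, one per expert; top-2 routing.
    let logits = [0.1, 2.0, -1.0, 1.0, 0.0, 0.5, -0.5, 1.5];
    let (experts, weights) = top_k_route(&logits, 2);
    // Experts 1 and 7 are selected; their weights sum to 1.
    println!("experts {:?} with weights {:?}", experts, weights);
}
```

Note that the softmax is taken over the selected logits only, so the two routing weights always sum to 1 regardless of how the other six experts scored.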
## Performance

- Sparse activation - Only 2 of 8 experts are active per token (~25% of the dense expert FLOPs)
- Fused SwiGLU kernel - 10-12x speedup for expert computation
- 4-bit quantization - Enables running 8x7B models on consumer hardware
- Metal acceleration - Optimized GPU kernels for Apple Silicon
## Memory requirements

| Model | Precision | Memory |
|---|---|---|
| Mixtral-8x7B | FP16 | ~90 GB |
| Mixtral-8x7B | 4-bit | ~24 GB |
| Mixtral-8x22B | 4-bit | ~48 GB |
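The 8x7B figures follow from simple arithmetic: weight memory ≈ parameter count × bits per weight / 8. Real usage runs slightly higher because of activations, the KV cache, and, for quantized models, the per-group scales and biases. A sketch, using ~46.7B parameters for 8x7B (a public figure):

```rust
/// Approximate weight memory in GB for a model stored at `bits` per weight.
fn weight_gb(params: f64, bits: f64) -> f64 {
    params * bits / 8.0 / 1e9
}

fn main() {
    // Raw weight storage only; quantization metadata and the KV cache add more.
    println!("8x7B FP16:  ~{:.0} GB", weight_gb(46.7e9, 16.0)); // ~93 GB
    println!("8x7B 4-bit: ~{:.0} GB", weight_gb(46.7e9, 4.0));  // ~23 GB
}
```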
## See also