
Overview

The mixtral-mlx crate provides high-performance inference for Mixtral MoE (Mixture of Experts) models on Apple Silicon. Mixtral uses sparse expert routing with 8 experts and top-2 selection per token, so each token pays the compute cost of only two experts while the model keeps the capacity of all eight.

Key features

  • 8 experts with top-2 routing - Each token activates only 2 of 8 expert networks
  • Custom fused SwiGLU kernel - 10-12x faster expert computation via Metal
  • Optimized gather_qmm - Efficient expert dispatch and gathering
  • 4-bit quantization - Reduced memory for running large MoE models

Installation

Add to your Cargo.toml:
[dependencies]
mixtral-mlx = "0.1"

Core functions

load_model

Loads a Mixtral model from a directory containing weights and configuration.
pub fn load_model(model_dir: impl AsRef<Path>) -> Result<Model, Error>
Parameters:
  • model_dir (impl AsRef<Path>, required) - Path to the model directory containing:
      • config.json - Model configuration
      • model.safetensors.index.json - Weight file index
      • model-*.safetensors - Model weights
Returns: Result<Model, Error> - a loaded Model ready for inference

load_tokenizer

Loads the tokenizer from the model directory.
pub fn load_tokenizer(model_dir: impl AsRef<Path>) -> Result<Tokenizer, Error>
Parameters:
  • model_dir (impl AsRef<Path>, required) - Path to the model directory containing tokenizer.json
Returns: Result<Tokenizer, Error> - a HuggingFace Tokenizer instance

get_model_args

Parses model configuration from config.json.
pub fn get_model_args(model_dir: impl AsRef<Path>) -> Result<ModelArgs, Error>
Parameters:
  • model_dir (impl AsRef<Path>, required) - Path to the directory containing config.json
Returns: Result<ModelArgs, Error> - parsed ModelArgs with model hyperparameters

Types

Model

The main model struct for Mixtral inference.
pub struct Model {
    pub args: ModelArgs,
    pub model: MixtralModel,
    pub lm_head: MaybeQuantized<nn::Linear>,
}
Fields:
  • args (ModelArgs) - Model configuration and hyperparameters
  • model (MixtralModel) - The core Mixtral transformer model
  • lm_head (MaybeQuantized<nn::Linear>) - Language modeling head projection

ModelArgs

Mixtral model configuration.
pub struct ModelArgs {
    pub model_type: String,
    pub vocab_size: i32,
    pub hidden_size: i32,
    pub intermediate_size: i32,
    pub num_hidden_layers: i32,
    pub num_attention_heads: i32,
    pub num_experts_per_tok: i32,     // Typically 2
    pub num_key_value_heads: Option<i32>,
    pub num_local_experts: i32,       // Typically 8
    pub rms_norm_eps: f32,
    pub rope_theta: f32,
    pub rope_traditional: bool,
    pub tie_word_embeddings: bool,
    pub quantization: Option<QuantizationConfig>,
}
Fields of note:
  • num_experts_per_tok (i32, default: 2) - Number of experts to activate per token (top-k routing)
  • num_local_experts (i32, default: 8) - Total number of expert networks in each MoE layer

Generate

Iterator for autoregressive text generation.
pub struct Generate<'a, C: KeyValueCache> {
    model: &'a mut Model,
    cache: &'a mut Vec<Option<C>>,
    temp: f32,
    state: GenerateState<'a>,
    prefetched: Option<Array>,
    token_count: usize,
}

Constructor

pub fn new(
    model: &'a mut Model,
    cache: &'a mut Vec<Option<C>>,
    temp: f32,
    prompt_token: &'a Array,
) -> Self
Parameters:
  • model (&'a mut Model, required) - Mutable reference to the loaded model
  • cache (&'a mut Vec<Option<C>>, required) - KV cache for attention (initially empty)
  • temp (f32, required) - Sampling temperature (0.0 = greedy, higher = more random)
  • prompt_token (&'a Array, required) - Encoded prompt tokens as an MLX array with shape [1, seq_len]

Example usage

Basic generation

use mixtral_mlx::{load_model, load_tokenizer, Generate, KVCache};
use mlx_rs::ops::indexing::NewAxis;

let model_dir = "models/Mixtral-8x7B-Instruct-v0.1";

// Load model and tokenizer
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

// Encode prompt
let encoding = tokenizer.encode("Explain quantum computing in simple terms: ", true)?;
let prompt = mlx_rs::Array::from(encoding.get_ids()).index(NewAxis);

// Initialize cache
let mut cache = Vec::new();

// Generate tokens
let generator = Generate::<KVCache>::new(&mut model, &mut cache, 0.6, &prompt);

for token in generator.take(200) {
    let token = token?;
    let text = tokenizer.decode(&[token.item::<u32>()], true)?;
    print!("{}", text);
}

With Mistral instruction format

use mixtral_mlx::{load_model, load_tokenizer, Generate, KVCache};
use mlx_rs::ops::indexing::NewAxis;

let model_dir = "models/Mixtral-8x7B-Instruct-v0.1";
let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;

// Format instruction prompt
let instruction = "Write a Python function to calculate Fibonacci numbers.";
let prompt_text = format!("[INST] {} [/INST]", instruction);

let encoding = tokenizer.encode(&prompt_text, true)?;
let prompt = mlx_rs::Array::from(encoding.get_ids()).index(NewAxis);

let mut cache = Vec::new();
let generator = Generate::<KVCache>::new(&mut model, &mut cache, 0.5, &prompt);

for token in generator.take(300) {
    let token = token?;
    let id = token.item::<u32>();
    
    // Stop on EOS
    if id == 2 {
        break;
    }
    
    let text = tokenizer.decode(&[id], true)?;
    print!("{}", text);
}

Loading quantized models

use mixtral_mlx::{load_model, load_tokenizer};

// 4-bit quantized model significantly reduces memory
let model_dir = "models/Mixtral-8x7B-Instruct-v0.1-4bit";

let tokenizer = load_tokenizer(model_dir)?;
let mut model = load_model(model_dir)?;  // Automatically detects quantization

// Use normally - quantization is transparent
// Memory usage for Mixtral-8x7B: ~90 GB (FP16) -> ~24 GB (4-bit)

Architecture components

MixtralSparseMoeBlock

Sparse Mixture of Experts layer with top-k routing.
pub struct MixtralSparseMoeBlock {
    pub num_experts: i32,
    pub top_k: i32,
    pub gate: nn::Linear,
    pub experts: Vec<SwitchGLU>,
}
Fields:
  • num_experts (i32) - Total number of expert networks (typically 8)
  • top_k (i32) - Number of experts to select per token (typically 2)
  • gate (nn::Linear) - Router network that scores all experts for each token
  • experts (Vec<SwitchGLU>) - Vector of expert feed-forward networks

SwitchGLU

Expert network using SwiGLU activation.
pub struct SwitchGLU {
    pub gate_proj: MaybeQuantized<nn::Linear>,
    pub up_proj: MaybeQuantized<nn::Linear>,
    pub down_proj: MaybeQuantized<nn::Linear>,
}
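Each expert computes down_proj(silu(gate_proj(x)) * up_proj(x)). The following is a minimal plain-Rust sketch of that forward pass on a 1-D vector, illustrative only: the crate runs the same math as a fused Metal kernel over MLX arrays, not with these hypothetical helper functions.

```rust
// Sketch of one SwitchGLU expert forward pass, assuming row-major
// weight matrices stored as Vec<Vec<f32>> (illustration, not the crate API).
fn matvec(w: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    w.iter()
        .map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum())
        .collect()
}

// SiLU (a.k.a. swish): x * sigmoid(x)
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

// down_proj(silu(gate_proj(x)) * up_proj(x))
fn swiglu_expert(
    gate: &[Vec<f32>],
    up: &[Vec<f32>],
    down: &[Vec<f32>],
    x: &[f32],
) -> Vec<f32> {
    let g = matvec(gate, x);
    let u = matvec(up, x);
    // Elementwise gating in the intermediate dimension.
    let h: Vec<f32> = g.iter().zip(&u).map(|(a, b)| silu(*a) * b).collect();
    matvec(down, &h)
}

fn main() {
    // Degenerate 1x1 "expert" just to show the data flow.
    let y = swiglu_expert(&[vec![1.0]], &[vec![2.0]], &[vec![1.0]], &[1.0]);
    println!("{:?}", y);
}
```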

Attention

Standard multi-head attention with GQA and RoPE.
pub struct Attention {
    pub n_heads: i32,
    pub n_kv_heads: i32,
    pub scale: f32,
    pub q_proj: MaybeQuantized<nn::Linear>,
    pub k_proj: MaybeQuantized<nn::Linear>,
    pub v_proj: MaybeQuantized<nn::Linear>,
    pub o_proj: MaybeQuantized<nn::Linear>,
    pub rope: nn::Rope,
}
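Because n_kv_heads can be smaller than n_heads (grouped-query attention), consecutive query heads share one key/value head. A minimal sketch of that index mapping in plain Rust (independent of the crate; the function name is illustrative):

```rust
// GQA head mapping: with n_kv_heads < n_heads, each group of
// n_heads / n_kv_heads consecutive query heads reads the same KV head.
fn kv_head_for(query_head: usize, n_heads: usize, n_kv_heads: usize) -> usize {
    query_head / (n_heads / n_kv_heads)
}

fn main() {
    // Example: 32 query heads sharing 8 KV heads -> groups of 4.
    for q in [0, 3, 4, 31] {
        println!("query head {q} -> kv head {}", kv_head_for(q, 32, 8));
    }
}
```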

DecoderLayer

Transformer layer with MoE instead of standard MLP.
pub struct DecoderLayer {
    pub self_attn: Attention,
    pub block_sparse_moe: MixtralSparseMoeBlock,
    pub input_layernorm: nn::RmsNorm,
    pub post_attention_layernorm: nn::RmsNorm,
}

Mixture of Experts routing

Mixtral uses top-k sparse routing:
// For each token, compute router logits for all experts
let router_logits = self.gate.forward(x)?;  // [batch, seq, num_experts]

// Select top-k experts
let (top_k_logits, top_k_indices) = router_logits.top_k(self.top_k)?;

// Softmax over selected experts
let routing_weights = mlx_rs::ops::softmax(&top_k_logits, -1)?;

// Dispatch to selected experts and combine outputs
let expert_outputs = dispatch_to_experts(x, &top_k_indices)?;
let output = combine_expert_outputs(&expert_outputs, &routing_weights)?;
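The routing above can be sketched for a single token in plain Rust, with no MLX dependency. This illustrates the top-k-then-softmax pattern only; it is not the crate's actual dispatch kernel.

```rust
// Top-k routing for one token: pick the k highest router logits,
// then softmax over just those k to get mixing weights.
fn top_k_routing(logits: &[f32], top_k: usize) -> Vec<(usize, f32)> {
    // Rank experts by router logit, descending.
    let mut idx: Vec<usize> = (0..logits.len()).collect();
    idx.sort_by(|&a, &b| logits[b].partial_cmp(&logits[a]).unwrap());
    let selected = &idx[..top_k];

    // Numerically stabilized softmax over the selected logits only.
    let max = selected.iter().map(|&i| logits[i]).fold(f32::MIN, f32::max);
    let exps: Vec<f32> = selected.iter().map(|&i| (logits[i] - max).exp()).collect();
    let sum: f32 = exps.iter().sum();

    selected.iter().zip(exps).map(|(&i, e)| (i, e / sum)).collect()
}

fn main() {
    // Router logits for one token over 8 experts.
    let logits = [0.1, 2.0, -0.5, 1.2, 0.0, 0.3, -1.0, 0.8];
    let routes = top_k_routing(&logits, 2);
    // Experts 1 and 3 win; their weights sum to 1.
    println!("{:?}", routes);
}
```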

Performance notes

  • Sparse activation - Only 2/8 experts active per token (~25% FLOPs)
  • Fused SwiGLU kernel - 10-12x speedup for expert computation
  • 4-bit quantization - Enables running 8x7B models on consumer hardware
  • Metal acceleration - Optimized GPU kernels for Apple Silicon
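The sparse-activation saving follows from simple arithmetic: each expert holds three hidden x intermediate projection matrices, and only top_k of num_local_experts run per token. A back-of-envelope sketch using the publicly documented Mixtral-8x7B hyperparameters (hidden 4096, intermediate 14336, 32 layers); this is illustration, not part of the crate's API:

```rust
// Parameters in one SwitchGLU expert: gate_proj + up_proj + down_proj,
// each a hidden x intermediate matrix (biases ignored).
fn expert_params(hidden: f64, intermediate: f64) -> f64 {
    3.0 * hidden * intermediate
}

fn main() {
    let (hidden, inter, layers) = (4096.0, 14336.0, 32.0);
    let per_expert = expert_params(hidden, inter);  // ~176M params
    let total_experts = layers * 8.0 * per_expert;  // all 8 experts, every layer
    let active_experts = layers * 2.0 * per_expert; // top-2 actually computed
    println!(
        "expert params: {:.1}B total, {:.1}B active per token",
        total_experts / 1e9,
        active_experts / 1e9
    );
}
```

Attention, embeddings, and the router run for every token regardless, which is why the overall FLOPs fraction lands near (not exactly at) 2/8.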

Memory requirements

Model           Precision   Memory
Mixtral-8x7B    FP16        ~90 GB
Mixtral-8x7B    4-bit       ~24 GB
Mixtral-8x22B   4-bit       ~48 GB
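These figures are roughly parameter count x bytes per weight; quantized checkpoints carry extra per-group scales and biases, so real usage runs somewhat higher than the bare estimate. A hedged helper in plain Rust, where the ~46.7B total-parameter count for Mixtral-8x7B is an approximate public figure used only for illustration:

```rust
// Weights-only memory estimate: params x bits / 8. Ignores KV cache,
// activations, and quantization metadata (group scales/biases).
fn weight_bytes(params: f64, bits_per_weight: f64) -> f64 {
    params * bits_per_weight / 8.0
}

fn main() {
    let gib = 1024f64.powi(3);
    let mixtral_8x7b = 46.7e9; // approximate total parameter count
    println!("FP16 : ~{:.0} GiB", weight_bytes(mixtral_8x7b, 16.0) / gib);
    println!("4-bit: ~{:.0} GiB", weight_bytes(mixtral_8x7b, 4.0) / gib);
}
```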
