mlx-rs-core

Shared inference infrastructure used across all model-specific MLX Rust crates (qwen3-mlx, glm4-mlx, gpt-sovits-mlx, etc.).

Overview

mlx-rs-core provides common components for efficient inference:
  • KV cache - Fast autoregressive decoding with key-value caching
  • Token generation - Generic generation infrastructure with sampling
  • Attention utilities - RoPE, attention masks, scaled dot-product attention
  • Custom kernels - Fused Metal kernels for performance-critical operations
  • Audio processing - Mel spectrogram and audio utilities
  • Error handling - Common error types and conversions

Module exports

pub use cache::{ConcatKeyValueCache, KVCache, KeyValueCache};
pub use error::{Error, Result};
pub use metal_kernels::{fused_swiglu, fused_modulate};
pub use sampler::{DefaultSampler, Sampler};
pub use utils::{
    create_attention_mask, initialize_rope, scaled_dot_product_attention,
    AttentionMask, FloatOrString, SdpaMask,
};
pub use tokenizers::Tokenizer;

Cache

KV cache implementations for efficient autoregressive generation.

KeyValueCache trait

KeyValueCache
trait
Trait for key-value caches used in attention mechanisms
pub trait KeyValueCache {
    fn offset(&self) -> i32;
    fn max_size(&self) -> Option<i32>;
    fn update_and_fetch(
        &mut self, 
        keys: Array, 
        values: Array
    ) -> Result<(Array, Array), Exception>;
}
offset
fn() -> i32
Returns current cache offset (number of tokens cached)
max_size
fn() -> Option<i32>
Returns maximum cache size for sliding window attention, if any
update_and_fetch
fn
Update cache with new keys/values and return full cache contents
keys
Array
New key tensors to add to cache
values
Array
New value tensors to add to cache
returns
(Array, Array)
Tuple of (full_keys, full_values) including new and cached tokens

ConcatKeyValueCache

ConcatKeyValueCache
struct
Simple concatenation-based KV cache. Caches by concatenating new keys/values with existing ones; simple, but can be slow for long sequences.
use mlx_rs_core::ConcatKeyValueCache;

let mut cache = ConcatKeyValueCache::new();
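To make the cache semantics concrete, here is a hypothetical sketch of concatenation-style caching on plain `Vec<Vec<f32>>` rows (one row per token) rather than MLX `Array`s; the type and method names mirror the `KeyValueCache` trait but are illustrative only, not the crate's implementation:

```rust
/// Toy concatenation cache: appends new rows and returns the full history.
struct ToyConcatCache {
    keys: Vec<Vec<f32>>,
    values: Vec<Vec<f32>>,
}

impl ToyConcatCache {
    fn new() -> Self {
        Self { keys: Vec::new(), values: Vec::new() }
    }

    /// Number of tokens cached so far (mirrors `KeyValueCache::offset`).
    fn offset(&self) -> i32 {
        self.keys.len() as i32
    }

    /// Append new key/value rows and return the full cache contents
    /// (mirrors `update_and_fetch`).
    fn update_and_fetch(
        &mut self,
        keys: Vec<Vec<f32>>,
        values: Vec<Vec<f32>>,
    ) -> (&[Vec<f32>], &[Vec<f32>]) {
        self.keys.extend(keys);
        self.values.extend(values);
        (&self.keys, &self.values)
    }
}
```

Each `update_and_fetch` call reallocates and copies the whole history, which is why this strategy degrades on long sequences.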

KVCache

KVCache
struct
Optimized step-based KV cache with pre-allocation. Pre-allocates buffers in steps of 256 tokens and uses in-place updates, avoiding expensive concatenation. Matches the Python mlx-lm implementation.
use mlx_rs_core::KVCache;

// Default: 256 token step size
let mut cache = KVCache::new();

// Custom step size
let mut cache = KVCache::with_step(512);
KVCache::new
fn() -> Self
Create cache with default step size of 256 tokens
KVCache::with_step
fn(step: i32) -> Self
Create cache with custom step size
step
i32
Number of tokens to pre-allocate per growth step
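The step-based growth strategy can be sketched as a small piece of arithmetic: the buffer capacity is the required length rounded up to the next multiple of `step`, so most updates fit in already-allocated space. The function name here is illustrative, not part of the crate's API:

```rust
/// Hypothetical sketch of step-based capacity growth: round the total
/// number of tokens up to the next multiple of `step`, as a
/// pre-allocating cache would before writing new keys/values in place.
fn required_capacity(offset: i32, new_tokens: i32, step: i32) -> i32 {
    let needed = offset + new_tokens;
    ((needed + step - 1) / step) * step
}
```

With the default step of 256, a cache holding 250 tokens that receives 10 more grows once to 512 and then absorbs the next ~250 tokens without reallocating.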

Generate

Generic token generation infrastructure with builder pattern.

Generate struct

Generate
struct
Iterator-based token generator with configurable sampling
use mlx_rs_core::Generate;

let generator = Generate::builder()
    .tokenizer(tokenizer)
    .model(model)
    .prompt(prompt_tokens)
    .temp(0.7)
    .max_tokens(256)
    .build();
    
for response in generator {
    let response = response?;
    println!("Generated: {}", response.text);
}

Builder

Generate::builder
fn() -> Builder
Create new generation builder
Builder::tokenizer
fn(tokenizer: Tokenizer) -> Self
Set tokenizer for decoding tokens to text
Builder::model
fn(model: M) -> Self
Set model for generation (must implement Module trait)
Builder::prompt
fn(prompt: Array) -> Self
Set prompt tokens as Array
Builder::temp
fn(temp: f32) -> Self
Set sampling temperature (default: 0.0 for greedy)
  • temp = 0.0 - Greedy decoding (argmax)
  • temp > 0.0 - Sampling with temperature
Builder::max_tokens
fn(max_tokens: usize) -> Self
Set maximum number of tokens to generate (default: 256)
Builder::sampler
fn(sampler: S) -> Self
Set custom sampler implementation
Builder::build
fn() -> Generate
Build the generator (consumes builder)

Response

Response
struct
Generated text response
pub struct Response {
    pub text: String,
    pub ids: Vec<u32>,
}
text
String
Decoded text from generated tokens
ids
Vec<u32>
Generated token IDs

Sampler

Token sampling strategies for generation.

Sampler trait

Sampler
trait
Trait for implementing custom sampling strategies
pub trait Sampler {
    fn sample(&mut self, logits: &Array, temp: f32) -> Result<Array, Exception>;
}
sample
fn(&mut self, logits: &Array, temp: f32) -> Result<Array, Exception>
Sample next token from logits
logits
&Array
Model output logits (unnormalized scores)
temp
f32
Temperature for sampling (0.0 = greedy)
returns
Array
Sampled token ID(s)

DefaultSampler

DefaultSampler
struct
Default sampling implementation
  • Temperature 0.0: Greedy decoding (argmax)
  • Temperature > 0.0: Categorical sampling with temperature scaling
use mlx_rs_core::DefaultSampler;

let mut sampler = DefaultSampler;
let token = sampler.sample(&logits, 0.7)?;
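The two sampling regimes described above can be sketched on plain `f32` slices; this is a reference illustration of the policy (argmax when `temp == 0.0`, a categorical draw from `softmax(logits / temp)` otherwise), not the crate's `DefaultSampler` code, and the random draw itself is omitted:

```rust
/// Greedy decoding: index of the largest logit (the temp == 0.0 path).
fn argmax(logits: &[f32]) -> usize {
    logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}

/// Temperature path: the distribution softmax(logits / temp) from which
/// a categorical sampler would then draw a token index.
fn softmax_with_temp(logits: &[f32], temp: f32) -> Vec<f32> {
    let scaled: Vec<f32> = logits.iter().map(|l| l / temp).collect();
    // Subtract the max before exponentiating for numerical stability.
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

Lower temperatures sharpen the distribution toward the argmax; higher ones flatten it toward uniform.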

Utilities (utils)

Attention utilities, RoPE initialization, and helper functions.

RoPE initialization

initialize_rope
fn
Initialize rotary position embedding
use mlx_rs_core::initialize_rope;

let rope = initialize_rope(
    dims,           // Head dimension
    base,           // Base frequency (typically 10000.0)
    traditional,    // Traditional vs modern RoPE
    &scaling_config, // Optional scaling config
    max_position_embeddings
)?;
dims
i32
Dimension of each attention head
base
f32
Base frequency for rotations (typically 10000.0)
traditional
bool
Use traditional RoPE formulation
scaling_config
&Option<HashMap<String, FloatOrString>>
Optional scaling configuration for extended context
  • "type": "default", "linear", etc.
  • "factor": Scaling factor
max_position_embeddings
i32
Maximum sequence length
returns
Result<nn::Rope>
Configured RoPE module
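For intuition, the standard rotary-embedding frequency table and the conventional "linear" scaling rule can be sketched as follows; this shows the usual formulation only (the crate's `initialize_rope` returns a configured `nn::Rope` module, and its internals may differ):

```rust
/// Standard RoPE inverse frequencies: inv_freq[i] = base^(-2i / dims)
/// for i in 0..dims/2, pairing up dimensions of each attention head.
fn rope_inv_freq(dims: i32, base: f32) -> Vec<f32> {
    (0..dims / 2)
        .map(|i| base.powf(-2.0 * i as f32 / dims as f32))
        .collect()
}

/// "linear" scaling conventionally divides positions by `factor`,
/// stretching the usable context window.
fn scaled_position(pos: f32, factor: f32) -> f32 {
    pos / factor
}
```

The first pair always rotates at frequency 1.0, and later pairs rotate progressively more slowly, which is what lets the embedding encode both fine and coarse positions.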

Attention masks

create_attention_mask
fn
Create causal attention mask for autoregressive generation
use mlx_rs_core::create_attention_mask;

let mask = create_attention_mask(
    &hidden_states,
    &cache,
    Some(return_array)
)?;
h
&Array
Hidden states tensor (shape: [batch, seq_len, …])
cache
&[Option<C>]
KV cache array (one per layer)
return_array
Option<bool>
Force returning explicit mask array instead of hardware causal
returns
Option<AttentionMask>
  • None for single token (no mask needed)
  • Some(AttentionMask::Causal) for hardware-optimized causal
  • Some(AttentionMask::Array) for explicit mask array
AttentionMask
enum
Attention mask variants
pub enum AttentionMask {
    Array(Array),  // Explicit mask array
    Causal,        // Hardware-optimized causal mask
}
SdpaMask
enum
Mask type for scaled dot-product attention
pub enum SdpaMask<'a> {
    Causal,           // Hardware causal mask
    Array(&'a Array), // Explicit mask reference
}
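The `AttentionMask::Array` variant corresponds to an explicit additive causal mask. As a hypothetical sketch (nested `Vec`s standing in for an MLX `Array`), for `seq` new query tokens with `offset` tokens already in the cache the mask has shape `[seq, seq + offset]`, with 0.0 where attention is allowed and negative infinity where it is not:

```rust
/// Explicit additive causal mask: entry (i, j) is 0.0 when query i may
/// attend to key j (all cached tokens plus earlier new tokens), and
/// -inf otherwise.
fn causal_mask(seq: usize, offset: usize) -> Vec<Vec<f32>> {
    (0..seq)
        .map(|i| {
            (0..seq + offset)
                .map(|j| if j <= i + offset { 0.0 } else { f32::NEG_INFINITY })
                .collect()
        })
        .collect()
}
```

With `seq == 1` the single row is all zeros, which is why `create_attention_mask` can return `None` for single-token decoding.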

Scaled dot-product attention

scaled_dot_product_attention
fn
Compute scaled dot-product attention
use mlx_rs_core::{scaled_dot_product_attention, SdpaMask};

let output = scaled_dot_product_attention(
    queries,
    keys,
    values,
    None,           // Optional cache
    scale,          // 1.0 / sqrt(head_dim)
    Some(SdpaMask::Causal)
)?;
queries
Array
Query tensor [batch, n_heads, seq_q, head_dim]
keys
Array
Key tensor [batch, n_kv_heads, seq_k, head_dim]
values
Array
Value tensor [batch, n_kv_heads, seq_v, head_dim]
cache
Option<C>
Optional KV cache
scale
f32
Attention scale factor (typically 1.0 / sqrt(head_dim))
mask
Option<SdpaMask>
Optional attention mask
returns
Array
Attention output [batch, n_heads, seq_q, head_dim]
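The computation itself is softmax(Q·Kᵀ·scale + mask)·V. Here is a single-query, single-head reference in plain Rust, illustrating the math rather than the crate's (batched, Metal-accelerated) implementation:

```rust
/// Single-query reference of scaled dot-product attention: softmax over
/// (q . k_j * scale + mask_j), then a weighted sum of value rows.
fn sdpa_one_query(
    q: &[f32],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
    scale: f32,
    mask: &[f32], // additive mask, one entry per key (-inf blocks a key)
) -> Vec<f32> {
    let scores: Vec<f32> = keys
        .iter()
        .zip(mask)
        .map(|(k, m)| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale + m)
        .collect();
    // Stable softmax over the scores.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // Weighted sum of value rows.
    let mut out = vec![0.0; values[0].len()];
    for (w, v) in exps.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += (w / sum) * x;
        }
    }
    out
}
```

A key masked with negative infinity gets zero weight after the softmax, so its value row contributes nothing to the output.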

Metal kernels

Custom fused Metal kernels for performance-critical operations.

fused_swiglu

fused_swiglu
fn(x: &Array, gate: &Array) -> Result<Array>
Fused SwiGLU activation using a custom Metal kernel. Computes: silu(gate) * x = (gate / (1 + exp(-gate))) * x. Performance: 10-12x faster than separate silu() + multiply() calls; critical for MoE models with many SwiGLU operations.
use mlx_rs_core::fused_swiglu;

let x = array!([1.0, 2.0, 3.0]);
let gate = array!([0.5, 1.0, 1.5]);
let output = fused_swiglu(&x, &gate)?;
x
&Array
Input tensor (any shape)
gate
&Array
Gate tensor (same shape as x)
returns
Array
SwiGLU output (same shape as inputs)
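The math the kernel fuses is easy to verify elementwise; this scalar reference sketch (not the Metal kernel itself) computes the same quantity on plain slices:

```rust
/// Elementwise reference of the fused kernel's math:
/// silu(gate) * x, with silu(g) = g / (1 + exp(-g)).
fn swiglu_ref(x: &[f32], gate: &[f32]) -> Vec<f32> {
    x.iter()
        .zip(gate)
        .map(|(xi, g)| (g / (1.0 + (-g).exp())) * xi)
        .collect()
}
```

Note that silu(0) = 0 and silu(g) ≈ g for large positive g, so the gate smoothly interpolates between suppressing and passing through `x` scaled by `gate`.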

fused_modulate

fused_modulate
fn(x: &Array, shift: &Array, scale: &Array) -> Result<Array>
Fused LayerNorm + modulation using a custom Metal kernel. Computes: (1 + scale) * LayerNorm(x) + shift, where the LayerNorm has no learnable parameters (elementwise_affine=False). Performance: fuses 7+ operations into a single kernel. Critical for DiT (Diffusion Transformer) models:
  • 4 modulate calls per block
  • 60 blocks
  • 40 forward passes per generation
  • = 9,600 modulate calls per image
use mlx_rs_core::fused_modulate;

let x = array!([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
let shift = array!([0.1, 0.2, 0.3]);
let scale = array!([0.5, 0.6, 0.7]);

let output = fused_modulate(&x, &shift, &scale)?;
x
&Array
Input tensor [batch, seq, dim] or [seq, dim]
shift
&Array
Shift tensor (flattened to [dim])
scale
&Array
Scale tensor (flattened to [dim])
returns
Array
Modulated output (same shape as x)
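The fused computation can be checked against a straightforward per-row reference: a parameter-free LayerNorm over the last dimension followed by the scale-and-shift. This sketch illustrates the math, not the Metal kernel (the epsilon of 1e-5 is an assumed typical value):

```rust
/// Per-row reference of the fused kernel's math: normalize over the last
/// dimension without learnable affine, then (1 + scale) * normed + shift.
fn modulate_row(x: &[f32], shift: &[f32], scale: &[f32]) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let inv_std = 1.0 / (var + 1e-5).sqrt();
    x.iter()
        .zip(shift.iter().zip(scale))
        .map(|(v, (sh, sc))| (1.0 + sc) * ((v - mean) * inv_std) + sh)
        .collect()
}
```

With zero shift and zero scale this reduces to plain LayerNorm, whose output is mean-zero per row; the shift then moves that mean and the scale stretches the variance.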

Model input/output traits

Traits for generic generation infrastructure.

ModelInput trait

ModelInput
trait
Trait for model input types that can be constructed from builder
pub trait ModelInput<'a, C, T> {
    fn from_model_input_builder(builder: ModelInputBuilder<'a, C, T>) -> Self;
}
Models implement this to receive tokens, cache, and state during generation.

ModelOutput trait

ModelOutput
trait
Trait for model output types that provide logits
pub trait ModelOutput {
    fn logits(&self) -> &Array;
}
Implemented by model outputs to extract next-token logits.

Tokenizer loading

load_tokenizer
fn(model_dir: impl AsRef<Path>) -> Result<Tokenizer>
Load tokenizer from a model directory. Loads tokenizer.json from the specified directory.
use mlx_rs_core::load_tokenizer;

let tokenizer = load_tokenizer("models/qwen-1.5-0.5b")?;
model_dir
impl AsRef<Path>
Path to model directory containing tokenizer.json
returns
Result<Tokenizer>
Loaded tokenizer instance

Error handling

Error
enum
Common error type for mlx-rs-core operations
pub enum Error {
    Mlx(Exception),
    Tokenizer(String),
    // ... other variants
}
Result
type
Result type alias with mlx-rs-core Error
pub type Result<T> = std::result::Result<T, Error>;

Helper macros

try_unwrap!
macro
Helper macro for early returns in iterator contexts
// In iterator that returns Option<Result<T, E>>
let value = try_unwrap!(some_operation());
// Automatically converts errors and returns Some(Err(e))
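A plausible shape for such a macro is sketched below; the crate's actual definition may differ, and `next_item` is a hypothetical caller used only for illustration:

```rust
/// Hypothetical sketch of a try_unwrap!-style macro: unwrap a Result
/// inside a function returning Option<Result<T, E>>, early-returning
/// Some(Err(..)) (with error conversion via Into) on failure.
macro_rules! try_unwrap {
    ($expr:expr) => {
        match $expr {
            Ok(v) => v,
            Err(e) => return Some(Err(e.into())),
        }
    };
}

/// Illustrative iterator-style step: doubles the value on success,
/// propagates the error as Some(Err(..)) on failure.
fn next_item(input: Result<i32, String>) -> Option<Result<i32, String>> {
    let value = try_unwrap!(input);
    Some(Ok(value * 2))
}
```

The early return yields `Some(Err(e))` rather than `None`, so an `Iterator::next` implementation can surface the error to the caller instead of silently ending iteration.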

Audio processing

The audio module provides utilities for audio processing:
  • Mel spectrogram computation
  • Audio preprocessing for speech models
  • Feature extraction utilities
(See audio module documentation for detailed API)

Speculative decoding

The speculative module provides support for speculative decoding to accelerate generation:
  • Draft model integration
  • Verification and acceptance logic
  • Multi-token generation strategies
(See speculative module documentation for detailed API)

Convert utilities

The convert module (requires convert feature) provides model conversion utilities:
#[cfg(feature = "convert")]
use mlx_rs_core::convert;
Utilities for converting models from other frameworks to MLX format.

Example usage

use mlx_rs_core::{
    Generate, KVCache, DefaultSampler, load_tokenizer,
    initialize_rope, create_attention_mask,
};
use mlx_rs::{module::Module, Array};

// Load model and tokenizer
let model = YourModel::new();
let tokenizer = load_tokenizer("path/to/model")?;

// Encode prompt
let encoding = tokenizer.encode("Hello, world!", true)?;
let prompt = Array::from_slice(
    encoding.get_ids(),
    &[encoding.len() as i32]
);

// Generate tokens
let generator = Generate::builder()
    .tokenizer(tokenizer)
    .model(model)
    .prompt(prompt)
    .temp(0.7)
    .max_tokens(100)
    .build();

for response in generator {
    let response = response?;
    println!("Generated: {}", response.text);
}

Feature flags

  • convert - Enable model conversion utilities
[dependencies]
mlx-rs-core = { version = "0.1", features = ["convert"] }
