KV cache (key-value cache) stores attention keys and values from previous tokens during autoregressive generation, eliminating redundant computation. OminiX-MLX implements two cache strategies optimized for Apple Silicon.

Overview

During autoregressive generation, each new token attends to all previous tokens. Without caching, the model would recompute keys/values for all past tokens at every step:
Step 1: Q₁ attends to K₁, V₁
Step 2: Q₂ attends to K₁, K₂, V₁, V₂  ← Recomputes K₁, V₁
Step 3: Q₃ attends to K₁, K₂, K₃, V₁, V₂, V₃  ← Recomputes K₁, K₂, V₁, V₂
With KV cache:
Step 1: Q₁ attends to K₁, V₁ → Cache (K₁, V₁)
Step 2: Q₂ attends to [K₁, K₂], [V₁, V₂] → Update cache
Step 3: Q₃ attends to [K₁, K₂, K₃], [V₁, V₂, V₃] → Update cache
With caching, each step computes keys/values only for the new token, cutting the total key/value work for an n-token generation from O(n²) to O(n) in sequence length. (Attention itself still reads the full cache at every step.)
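To make the asymptotics concrete, here is a small self-contained sketch (plain Rust, no mlx types; the function name is illustrative) counting how many single-token key/value computations a full generation performs with and without caching:

```rust
/// Count how many single-token key/value computations an n-token
/// generation performs, with and without a KV cache.
fn kv_computations(num_tokens: u64, cached: bool) -> u64 {
    if cached {
        // Each step computes K/V only for the newly generated token.
        num_tokens
    } else {
        // Step t recomputes K/V for all t tokens seen so far:
        // 1 + 2 + ... + n = n(n + 1)/2.
        num_tokens * (num_tokens + 1) / 2
    }
}
```

For the three-step diagram above this gives 6 computations uncached versus 3 cached; at 2048 tokens the gap is roughly 2.1 million versus 2048.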

Cache implementations

KeyValueCache trait

All caches implement the KeyValueCache trait from mlx-rs-core/src/cache.rs:
mlx-rs-core/src/cache.rs
pub trait KeyValueCache {
    /// Returns the current offset (number of tokens in cache)
    fn offset(&self) -> i32;

    /// Returns the maximum cache size (for sliding window), if any
    fn max_size(&self) -> Option<i32>;

    /// Update cache with new keys/values and return full cache contents
    fn update_and_fetch(&mut self, keys: Array, values: Array) 
        -> Result<(Array, Array), Exception>;
}
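To illustrate the trait's contract without pulling in mlx types, here is a hypothetical stand-in with `Vec<f32>` replacing `Array` (the `ToyKeyValueCache`/`ToyConcatCache` names are illustrative, not part of mlx-rs-core):

```rust
/// Toy mirror of the KeyValueCache contract, with Vec<f32> standing
/// in for mlx Arrays (illustrative only; not part of mlx-rs-core).
trait ToyKeyValueCache {
    /// Number of tokens currently cached.
    fn offset(&self) -> usize;
    /// Append new keys/values and return the full cache contents.
    fn update_and_fetch(&mut self, keys: &[f32], values: &[f32]) -> (Vec<f32>, Vec<f32>);
}

#[derive(Default)]
struct ToyConcatCache {
    keys: Vec<f32>,
    values: Vec<f32>,
}

impl ToyKeyValueCache for ToyConcatCache {
    fn offset(&self) -> usize {
        self.keys.len()
    }

    fn update_and_fetch(&mut self, keys: &[f32], values: &[f32]) -> (Vec<f32>, Vec<f32>) {
        // Concatenate new entries, then return everything cached so far.
        self.keys.extend_from_slice(keys);
        self.values.extend_from_slice(values);
        (self.keys.clone(), self.values.clone())
    }
}
```

The key property the real trait shares: `update_and_fetch` both appends and returns the full history, so the attention code never needs to know how the cache stores it.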

ConcatKeyValueCache

Simple concatenation-based cache. Best for short sequences or when memory is not a constraint.
mlx-rs-core/src/cache.rs
/// Simple concatenation-based KV cache
#[derive(Debug, Clone, Default)]
pub struct ConcatKeyValueCache {
    keys: Option<Array>,
    values: Option<Array>,
    offset: i32,
}

impl ConcatKeyValueCache {
    pub fn new() -> Self {
        Self::default()
    }
}

impl KeyValueCache for ConcatKeyValueCache {
    fn update_and_fetch(&mut self, keys: Array, values: Array) 
        -> Result<(Array, Array), Exception> {
        match (self.keys.take(), self.values.take()) {
            (Some(k), Some(v)) => {
                // Concatenate along sequence dimension (axis=-2)
                self.keys = Some(concatenate_axis(&[k, keys], -2)?);
                self.values = Some(concatenate_axis(&[v, values], -2)?);
            }
            _ => {
                // First token: just store
                self.keys = Some(keys);
                self.values = Some(values);
            }
        }
        
        let shape = self.keys.as_ref().expect("Keys cannot be None").shape();
        self.offset = shape[shape.len() - 2];

        Ok((
            self.keys.clone().expect("Keys cannot be None"),
            self.values.clone().expect("Values cannot be None"),
        ))
    }
}
Characteristics:
  • Allocates new memory on every token
  • Simple and reliable
  • Good for sequences < 512 tokens
  • Default in most model examples

KVCache (step-based pre-allocation)

Pre-allocates buffers in steps of 256 tokens and writes new entries in place, matching the Python mlx-lm implementation.
mlx-rs-core/src/cache.rs
/// Step-based KV Cache with pre-allocation (matches Python mlx-lm KVCache)
///
/// This cache pre-allocates buffers in steps of 256 tokens and uses in-place
/// slice updates, avoiding expensive concatenation on every token.
#[derive(Debug, Clone)]
pub struct KVCache {
    keys: Option<Array>,
    values: Option<Array>,
    offset: i32,
    step: i32,  // Default: 256
}

impl KVCache {
    pub fn new() -> Self {
        Self::with_step(256)
    }

    pub fn with_step(step: i32) -> Self {
        Self {
            keys: None,
            values: None,
            offset: 0,
            step,
        }
    }
}
Update logic:
mlx-rs-core/src/cache.rs
fn update_and_fetch(&mut self, keys: Array, values: Array) 
    -> Result<(Array, Array), Exception> {
    let prev = self.offset;
    let num_new = keys.shape()[2];

    // Check if we need to grow the buffer
    let needs_grow = match &self.keys {
        None => true,
        Some(k) => (prev + num_new) > k.shape()[2],
    };

    if needs_grow {
        // Pre-allocate in multiples of the step size (default 256),
        // rounding the incoming token count up to a full step
        let n_steps = (self.step + num_new - 1) / self.step;
        let mut k_shape = keys.shape().to_vec();
        let mut v_shape = values.shape().to_vec();
        k_shape[2] = n_steps * self.step;
        v_shape[2] = n_steps * self.step;

        let new_k = zeros_dtype(&k_shape, keys.dtype())?;
        let new_v = zeros_dtype(&v_shape, values.dtype())?;

        // Append the fresh buffer to the existing one, if any
        self.keys = Some(match self.keys.take() {
            Some(old_k) => concatenate_axis(&[old_k, new_k], 2)?,
            None => new_k,
        });
        self.values = Some(match self.values.take() {
            Some(old_v) => concatenate_axis(&[old_v, new_v], 2)?,
            None => new_v,
        });
    }

    self.offset += num_new;

    // In-place update using slice indexing
    let k = self.keys.as_mut().unwrap();
    let v = self.values.as_mut().unwrap();
    k.index_mut((Ellipsis, prev..self.offset, ..), &keys);
    v.index_mut((Ellipsis, prev..self.offset, ..), &values);

    // Return only the valid region of the buffers
    Ok((
        k.index((Ellipsis, ..self.offset, ..)),
        v.index((Ellipsis, ..self.offset, ..)),
    ))
}
Characteristics:
  • Pre-allocates in chunks of 256 tokens
  • In-place updates via slice assignment
  • Reduces allocation overhead
  • Optimal for long sequences (512+ tokens)
  • Used in production LLM deployments
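The growth rule above rounds the incoming token count up to a multiple of the step size. That arithmetic can be sketched on its own (illustrative helper, not a crate API):

```rust
/// Capacity added when the cache grows: the new-token count rounded
/// up to a multiple of the pre-allocation step (default 256).
fn grow_by(step: i64, num_new: i64) -> i64 {
    let n_steps = (step + num_new - 1) / step;
    n_steps * step
}
```

A single decode token grows the buffer by one full step (256), a 300-token prefill by two steps (512), and an exact multiple of the step adds no slack.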

Usage in models

Basic usage

use mlx_rs_core::{ConcatKeyValueCache, KeyValueCache};

// Initialize cache (one per Transformer layer)
let num_layers = 32;
let mut cache: Vec<Option<ConcatKeyValueCache>> = 
    vec![None; num_layers];

// First token (prefill)
let (k_cached, v_cached) = match &mut cache[layer_idx] {
    None => {
        let mut c = ConcatKeyValueCache::new();
        let result = c.update_and_fetch(keys, values)?;
        cache[layer_idx] = Some(c);
        result
    }
    Some(c) => c.update_and_fetch(keys, values)?,
};

// Use cached keys/values in attention
let output = scaled_dot_product_attention(
    queries, k_cached, v_cached, None, scale, Some(SdpaMask::Causal)
)?;
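The init-or-update `match` above can be collapsed with `Option::get_or_insert_with`; the pattern is sketched here on a hypothetical toy cache type so it stands alone without mlx:

```rust
// Collapsing the per-layer "create on first use, then update" match
// with Option::get_or_insert_with (ToyCache is illustrative only).
#[derive(Default)]
struct ToyCache {
    offset: usize,
}

impl ToyCache {
    fn update(&mut self, num_new: usize) -> usize {
        self.offset += num_new;
        self.offset
    }
}

fn layer_step(cache: &mut [Option<ToyCache>], layer: usize, num_new: usize) -> usize {
    // Creates the layer's cache on first use, then updates in place.
    cache[layer].get_or_insert_with(ToyCache::default).update(num_new)
}
```

With the real `ConcatKeyValueCache`, the same one-liner replaces the eight-line `match` without changing behavior.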

With KVCache (pre-allocation)

use mlx_rs_core::{KVCache, KeyValueCache};

// Use pre-allocating cache for long sequences
let mut cache: Vec<Option<KVCache>> = vec![None; num_layers];

// Rest is identical - trait abstraction handles the difference
let (k_cached, v_cached) = match &mut cache[layer_idx] {
    None => {
        let mut c = KVCache::new();  // step=256 default
        let result = c.update_and_fetch(keys, values)?;
        cache[layer_idx] = Some(c);
        result
    }
    Some(c) => c.update_and_fetch(keys, values)?,
};

In generation loop

Example from qwen3-mlx:
use qwen3_mlx::{load_model, Generate, ConcatKeyValueCache};

let mut model = load_model("./models/Qwen3-4B")?;
let mut cache = Vec::new();  // Auto-initialized

let generator = Generate::<ConcatKeyValueCache>::new(
    &mut model, 
    &mut cache, 
    0.7,  // temperature
    &prompt_tokens
);

for token in generator.take(100) {
    let token = token?;
    print!("{}", tokenizer.decode(&[token.item::<u32>()], true)?);
}

Cache shapes

Standard attention

Keys and values have shape:
[batch_size, num_kv_heads, seq_length, head_dim]
Example for Qwen3-4B:
  • batch_size: 1 (typical inference)
  • num_kv_heads: 8 (GQA - Grouped Query Attention)
  • seq_length: Grows with each token
  • head_dim: 128

Grouped Query Attention (GQA)

Many modern models use GQA where K/V heads < Q heads:
Model      Q Heads   KV Heads   Ratio
Qwen3-4B   32        8          4:1
Moxin-7B   32        8          4:1
GLM4-9B    16        8          2:1
This reduces KV cache memory by 2-4x with minimal quality loss.

Memory usage

For a single layer with BF16 precision:
memory_per_layer = 2 × batch_size × num_kv_heads × seq_length × head_dim × 2 bytes
Example for Qwen3-4B at 2048 tokens:
= 2 × 1 × 8 × 2048 × 128 × 2 bytes
= 8,388,608 bytes = 8 MB per layer

Total (32 layers): 256 MB
At 4096 tokens: 512 MB
At 8192 tokens: 1 GB
Long context models (32k+ tokens) can use 4-8 GB just for KV cache. Consider sliding window attention or other optimizations for very long sequences.
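The arithmetic above can be packaged as a small sizing helper (plain Rust, illustrative; the function name is not part of the crate):

```rust
/// Total KV-cache bytes across a model:
/// 2 (K and V) × layers × batch × KV heads × seq length × head dim × bytes/element.
fn kv_cache_bytes(
    layers: u64,
    batch: u64,
    kv_heads: u64,
    seq_len: u64,
    head_dim: u64,
    elem_bytes: u64,
) -> u64 {
    2 * layers * batch * kv_heads * seq_len * head_dim * elem_bytes
}
```

Plugging in the Qwen3-4B BF16 numbers at 2048 tokens reproduces the figures above: 8 MB for one layer, 256 MB for all 32.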

Optimization tips

Choose the right cache type

Use ConcatKeyValueCache for:
  • Short sequences (< 512 tokens)
  • Chat applications (typical turns are short)
  • Prototyping and development
Use KVCache for:
  • Long sequences (512+ tokens)
  • Production deployments
  • Document summarization
  • Code generation

Adjust step size

Customize pre-allocation step size based on your use case:
// Small step for interactive chat (less wasted memory)
let cache = KVCache::with_step(128);

// Large step for long-form generation (fewer allocations)
let cache = KVCache::with_step(512);
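Assuming one token per decode step, the step size trades slack memory against how often the buffer grows; the growth count is just a ceiling division (illustrative helper):

```rust
/// Number of buffer-growth events when decoding `total` tokens one
/// at a time with the given pre-allocation step (ceiling division).
fn growth_events(total: u64, step: u64) -> u64 {
    (total + step - 1) / step
}
```

For a 1000-token generation, step 128 grows the buffer 8 times with at most 127 wasted slots, while step 512 grows it twice with up to 511 wasted slots.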

Batch processing

Reuse cache across multiple generations:
// Clear cache between generations
cache.clear();

// Or keep cache for continued generation
// (e.g., multi-turn conversation)

GQA benefits

Models with GQA use 2-4x less KV cache memory:
Standard MHA (32 heads):     32 × seq_len × head_dim × 2 × 2 bytes
GQA (32Q / 8KV heads):        8 × seq_len × head_dim × 2 × 2 bytes
Savings: 4x reduction

Performance benchmarks

Measured on Apple M3 Max with Qwen3-4B, 2048 token generation:
Cache Type            First Token   Tokens/sec   Memory
No cache              50 ms         8 tok/s      8 GB
ConcatKeyValueCache   55 ms         42 tok/s     8.3 GB
KVCache (step=256)    56 ms         45 tok/s     8.3 GB
Cache bookkeeping adds only a few milliseconds; the 5x+ throughput gain over no caching comes from avoiding recomputation of all prior tokens at every step.

Sliding window attention

For fixed-context models (e.g., Mistral with 4096 token window):
// Future enhancement - not yet implemented
pub struct SlidingWindowCache {
    window_size: i32,
    // ...
}

impl KeyValueCache for SlidingWindowCache {
    fn max_size(&self) -> Option<i32> {
        Some(self.window_size)
    }
    // Automatically drops oldest tokens when window is full
}
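The eviction policy such a cache would implement can be sketched independently of mlx, tracking token indices in a deque (illustrative only; the real cache would store key/value arrays):

```rust
use std::collections::VecDeque;

/// Toy sliding-window eviction: keeps at most `window_size` entries,
/// dropping the oldest when a new token arrives.
struct WindowIndex {
    window_size: usize,
    tokens: VecDeque<u32>,
}

impl WindowIndex {
    fn new(window_size: usize) -> Self {
        Self { window_size, tokens: VecDeque::new() }
    }

    fn push(&mut self, token: u32) {
        self.tokens.push_back(token);
        if self.tokens.len() > self.window_size {
            self.tokens.pop_front(); // evict the oldest token
        }
    }
}
```

With a window of 3, pushing tokens 1 through 5 leaves only 3, 4, 5 cached, which is exactly the memory bound `max_size` would advertise.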
