KV cache (key-value cache) stores attention keys and values from previous tokens during autoregressive generation, eliminating redundant computation. OminiX-MLX implements two cache strategies optimized for Apple Silicon.
Overview
During autoregressive generation, each new token attends to all previous tokens. Without caching, the model would recompute keys/values for all past tokens at every step:
Step 1: Q₁ attends to K₁, V₁
Step 2: Q₂ attends to K₁, K₂, V₁, V₂ ← Recomputes K₁, V₁
Step 3: Q₃ attends to K₁, K₂, K₃, V₁, V₂, V₃ ← Recomputes K₁, K₂, V₁, V₂
With KV cache:
Step 1: Q₁ attends to K₁, V₁ → Cache (K₁, V₁)
Step 2: Q₂ attends to [K₁, K₂], [V₁, V₂] → Update cache
Step 3: Q₃ attends to [K₁, K₂, K₃], [V₁, V₂, V₃] → Update cache
This reduces the key/value computation per step from O(n) to O(1), cutting the total cost of a generation from O(n²) to O(n) forward passes in sequence length (the attention score computation itself still scales with context length).
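The asymptotic difference can be sketched by counting key/value computations across a generation (a toy illustration, not part of the OminiX-MLX API):

```rust
// Count key/value projections over n generation steps,
// with and without a KV cache.
fn kv_ops_no_cache(n: u64) -> u64 {
    // Step t recomputes K/V for all t tokens seen so far
    (1..=n).sum()
}

fn kv_ops_with_cache(n: u64) -> u64 {
    // Step t computes K/V only for the newest token
    n
}

fn main() {
    assert_eq!(kv_ops_no_cache(3), 6); // 1 + 2 + 3
    assert_eq!(kv_ops_with_cache(3), 3);
    println!(
        "1024 tokens: {} ops uncached vs {} cached",
        kv_ops_no_cache(1024),
        kv_ops_with_cache(1024)
    );
}
```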
Cache implementations
KeyValueCache trait
All caches implement the KeyValueCache trait from mlx-rs-core/src/cache.rs:
pub trait KeyValueCache {
    /// Returns the current offset (number of tokens in cache)
    fn offset(&self) -> i32;

    /// Returns the maximum cache size (for sliding window), if any
    fn max_size(&self) -> Option<i32>;

    /// Update cache with new keys/values and return full cache contents
    fn update_and_fetch(&mut self, keys: Array, values: Array)
        -> Result<(Array, Array), Exception>;
}
ConcatKeyValueCache
Simple concatenation-based cache. Best for short sequences or when memory is not a constraint.
/// Simple concatenation-based KV cache
#[derive(Debug, Clone, Default)]
pub struct ConcatKeyValueCache {
    keys: Option<Array>,
    values: Option<Array>,
    offset: i32,
}

impl ConcatKeyValueCache {
    pub fn new() -> Self {
        Self::default()
    }
}
impl KeyValueCache for ConcatKeyValueCache {
    fn update_and_fetch(&mut self, keys: Array, values: Array)
        -> Result<(Array, Array), Exception> {
        match (self.keys.take(), self.values.take()) {
            (Some(k), Some(v)) => {
                // Concatenate along sequence dimension (axis=-2)
                self.keys = Some(concatenate_axis(&[k, keys], -2)?);
                self.values = Some(concatenate_axis(&[v, values], -2)?);
            }
            _ => {
                // First token: just store
                self.keys = Some(keys);
                self.values = Some(values);
            }
        }
        let shape = self.keys.as_ref().expect("Keys cannot be None").shape();
        self.offset = shape[shape.len() - 2];
        Ok((
            self.keys.clone().expect("Keys cannot be None"),
            self.values.clone().expect("Values cannot be None"),
        ))
    }
}
Characteristics:
- Allocates new memory on every token
- Simple and reliable
- Good for sequences < 512 tokens
- Default in most model examples
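The reason concatenation suits only shorter sequences can be seen by counting copy traffic (a toy sketch, not part of the crate): every decode step allocates a fresh buffer and copies the entire existing cache into it, so total copying grows quadratically.

```rust
// Total token-slots copied by concatenation-based caching over a generation.
// Step t re-copies t tokens of keys (and likewise values), so the sum is
// quadratic in sequence length.
fn concat_copy_traffic(tokens: u64) -> u64 {
    (1..=tokens).sum()
}

fn main() {
    assert_eq!(concat_copy_traffic(4), 10); // 1 + 2 + 3 + 4
    // A 512-token generation already re-copies ~131k token-slots per layer
    assert_eq!(concat_copy_traffic(512), 512 * 513 / 2);
}
```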
KVCache (step-based pre-allocation)
Pre-allocates buffers in steps of 256 tokens and updates them in place, matching the Python mlx-lm KVCache implementation.
/// Step-based KV Cache with pre-allocation (matches Python mlx-lm KVCache)
///
/// This cache pre-allocates buffers in steps of 256 tokens and uses in-place
/// slice updates, avoiding expensive concatenation on every token.
#[derive(Debug, Clone)]
pub struct KVCache {
    keys: Option<Array>,
    values: Option<Array>,
    offset: i32,
    step: i32, // Default: 256
}

impl KVCache {
    pub fn new() -> Self {
        Self::with_step(256)
    }

    pub fn with_step(step: i32) -> Self {
        Self {
            keys: None,
            values: None,
            offset: 0,
            step,
        }
    }
}
Update logic:
fn update_and_fetch(&mut self, keys: Array, values: Array)
    -> Result<(Array, Array), Exception> {
    let prev = self.offset;
    let num_new = keys.shape()[2];

    // Check if we need to grow the buffer
    let needs_grow = match &self.keys {
        None => true,
        Some(k) => (prev + num_new) > k.shape()[2],
    };

    if needs_grow {
        // Pre-allocate in multiples of the step size (default 256)
        let n_steps = (num_new + self.step - 1) / self.step;
        let mut k_shape = keys.shape().to_vec();
        let mut v_shape = values.shape().to_vec();
        k_shape[2] = n_steps * self.step;
        v_shape[2] = n_steps * self.step;
        let new_k = zeros_dtype(&k_shape, keys.dtype())?;
        let new_v = zeros_dtype(&v_shape, values.dtype())?;

        // Append the zero-filled buffers to any existing ones
        match (self.keys.take(), self.values.take()) {
            (Some(old_k), Some(old_v)) => {
                self.keys = Some(concatenate_axis(&[old_k, new_k], 2)?);
                self.values = Some(concatenate_axis(&[old_v, new_v], 2)?);
            }
            _ => {
                self.keys = Some(new_k);
                self.values = Some(new_v);
            }
        }
    }

    self.offset += num_new;

    // In-place update using slice indexing
    let k = self.keys.as_mut().unwrap();
    let v = self.values.as_mut().unwrap();
    k.index_mut((Ellipsis, prev..self.offset, ..), &keys);
    v.index_mut((Ellipsis, prev..self.offset, ..), &values);

    // Return only the populated prefix of the buffers
    Ok((
        k.index((Ellipsis, ..self.offset, ..)),
        v.index((Ellipsis, ..self.offset, ..)),
    ))
}
Characteristics:
- Pre-allocates in chunks of 256 tokens
- In-place updates via slice assignment
- Reduces allocation overhead
- Optimal for long sequences (512+ tokens)
- Used in production LLM deployments
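The allocation savings follow directly from the growth rule: requests are rounded up to a multiple of `step`, so decoding triggers one allocation per 256 tokens instead of one per token. A toy comparison (not part of the crate):

```rust
// Buffer allocations per generation for the two strategies.
fn concat_allocations(tokens: u64) -> u64 {
    tokens // ConcatKeyValueCache: one new buffer per decoded token
}

fn stepped_allocations(tokens: u64, step: u64) -> u64 {
    (tokens + step - 1) / step // KVCache: ceil(tokens / step)
}

fn main() {
    assert_eq!(concat_allocations(2048), 2048);
    assert_eq!(stepped_allocations(2048, 256), 8);
    // The very first decode token already pre-allocates a full 256-slot step
    assert_eq!(stepped_allocations(1, 256), 1);
}
```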
Usage in models
Basic usage
use mlx_rs_core::{ConcatKeyValueCache, KeyValueCache};

// Initialize cache (one per Transformer layer)
let num_layers = 32;
let mut cache: Vec<Option<ConcatKeyValueCache>> = vec![None; num_layers];

// First token (prefill)
let (k_cached, v_cached) = match &mut cache[layer_idx] {
    None => {
        let mut c = ConcatKeyValueCache::new();
        let result = c.update_and_fetch(keys, values)?;
        cache[layer_idx] = Some(c);
        result
    }
    Some(c) => c.update_and_fetch(keys, values)?,
};

// Use cached keys/values in attention
let output = scaled_dot_product_attention(
    queries, k_cached, v_cached, None, scale, Some(SdpaMask::Causal),
)?;
With KVCache (pre-allocation)
use mlx_rs_core::{KVCache, KeyValueCache};

// Use pre-allocating cache for long sequences
let mut cache: Vec<Option<KVCache>> = vec![None; num_layers];

// Rest is identical - the trait abstraction handles the difference
let (k_cached, v_cached) = match &mut cache[layer_idx] {
    None => {
        let mut c = KVCache::new(); // step=256 default
        let result = c.update_and_fetch(keys, values)?;
        cache[layer_idx] = Some(c);
        result
    }
    Some(c) => c.update_and_fetch(keys, values)?,
};
In generation loop
Example from qwen3-mlx:
use qwen3_mlx::{load_model, Generate, ConcatKeyValueCache};

let mut model = load_model("./models/Qwen3-4B")?;
let mut cache = Vec::new(); // Auto-initialized

let generator = Generate::<ConcatKeyValueCache>::new(
    &mut model,
    &mut cache,
    0.7, // temperature
    &prompt_tokens,
);

for token in generator.take(100) {
    let token = token?;
    print!("{}", tokenizer.decode(&[token.item::<u32>()], true)?);
}
Cache shapes
Standard attention
Keys and values have shape:
[batch_size, num_kv_heads, seq_length, head_dim]
Example for Qwen3-4B:
- batch_size: 1 (typical inference)
- num_kv_heads: 8 (GQA - Grouped Query Attention)
- seq_length: Grows with each token
- head_dim: 128
Grouped Query Attention (GQA)
Many modern models use GQA where K/V heads < Q heads:
| Model | Q Heads | KV Heads | Ratio |
|---|---|---|---|
| Qwen3-4B | 32 | 8 | 4:1 |
| Moxin-7B | 32 | 8 | 4:1 |
| GLM4-9B | 16 | 8 | 2:1 |
This reduces KV cache memory by 2-4x with minimal quality loss.
Memory usage
For a single layer with BF16 precision:
memory_per_layer = 2 × batch_size × num_kv_heads × seq_length × head_dim × 2 bytes
Example for Qwen3-4B at 2048 tokens:
= 2 × 1 × 8 × 2048 × 128 × 2 bytes
= 8,388,608 bytes = 8 MB per layer
Total (32 layers): 256 MB
At 4096 tokens: 512 MB
At 8192 tokens: 1 GB
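The arithmetic above can be reproduced with a small helper (an illustrative sketch; the function is not part of the crate):

```rust
// BF16 KV cache size per layer: K and V, 2 bytes per element.
fn kv_cache_bytes(batch: u64, kv_heads: u64, seq_len: u64, head_dim: u64) -> u64 {
    2 * batch * kv_heads * seq_len * head_dim * 2
}

fn main() {
    // Qwen3-4B figures from above: 8 KV heads, head_dim 128, 32 layers
    let per_layer = kv_cache_bytes(1, 8, 2048, 128);
    assert_eq!(per_layer, 8_388_608); // 8 MB per layer
    assert_eq!(32 * per_layer / (1024 * 1024), 256); // 256 MB total
    assert_eq!(32 * kv_cache_bytes(1, 8, 8192, 128), 1 << 30); // 1 GB at 8192 tokens
}
```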
Long context models (32k+ tokens) can use 4-8 GB just for KV cache. Consider sliding window attention or other optimizations for very long sequences.
Optimization tips
Choose the right cache type
Use ConcatKeyValueCache for:
- Short sequences (< 512 tokens)
- Chat applications (typical turns are short)
- Prototyping and development
Use KVCache for:
- Long sequences (512+ tokens)
- Production deployments
- Document summarization
- Code generation
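The selection guidance above can be summarized as a tiny helper; the enum, threshold, and function below are purely illustrative and not part of the OminiX-MLX API.

```rust
// Hypothetical cache selector following the guidance above.
#[derive(Debug, PartialEq)]
enum CacheStrategy {
    Concat,  // ConcatKeyValueCache: simple, allocates per token
    Stepped, // KVCache: pre-allocates in 256-token steps
}

fn choose_cache(expected_tokens: usize) -> CacheStrategy {
    if expected_tokens < 512 {
        CacheStrategy::Concat
    } else {
        CacheStrategy::Stepped
    }
}

fn main() {
    assert_eq!(choose_cache(128), CacheStrategy::Concat);   // short chat turn
    assert_eq!(choose_cache(4096), CacheStrategy::Stepped); // document summarization
}
```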
Adjust step size
Customize pre-allocation step size based on your use case:
// Small step for interactive chat (less wasted memory)
let cache = KVCache::with_step(128);
// Large step for long-form generation (fewer allocations)
let cache = KVCache::with_step(512);
Batch processing
Reset or reuse the cache across multiple generations. Note that calling `clear()` on the `Vec` would drop the per-layer entries entirely and break layer indexing; reset the entries to `None` instead:
// Reset per-layer caches between independent generations
for c in cache.iter_mut() {
    *c = None;
}
// Or keep the cache for continued generation
// (e.g., multi-turn conversation)
GQA benefits
Models with GQA use 2-4x less KV cache memory:
Standard MHA (32 heads): 32 × seq_len × head_dim × 2 × 2 bytes
GQA (32Q / 8KV heads): 8 × seq_len × head_dim × 2 × 2 bytes
Savings: 4x reduction
Performance
Measured on Apple M3 Max with Qwen3-4B, 2048 token generation:
| Cache Type | First Token | Tokens/sec | Memory |
|---|---|---|---|
| No cache | 50ms | 8 tok/s | 8 GB |
| ConcatKeyValueCache | 55ms | 42 tok/s | 8.3 GB |
| KVCache (step=256) | 56ms | 45 tok/s | 8.3 GB |
Cache overhead is minimal (< 5ms per token). The 5x+ speedup compared to no caching comes from avoiding recomputation of all prior tokens.
Sliding window attention
For fixed-context models (e.g., Mistral with 4096 token window):
// Future enhancement - not yet implemented
pub struct SlidingWindowCache {
    window_size: i32,
    // ...
}

impl KeyValueCache for SlidingWindowCache {
    fn max_size(&self) -> Option<i32> {
        Some(self.window_size)
    }

    // Automatically drops oldest tokens when window is full
}
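Since this cache is not yet implemented, the following is only a hypothetical sketch of the ring-buffer index arithmetic such a cache could use to overwrite the oldest token slot once the window is full:

```rust
// Hypothetical ring-buffer slot selection for a sliding-window cache:
// token t is written to slot t mod window, so once the window is full,
// each new token overwrites the oldest entry.
fn write_slot(token_index: i32, window: i32) -> i32 {
    token_index % window
}

fn main() {
    let window = 4;
    // Tokens 0..4 fill slots 0..4; tokens 4 and 5 wrap around.
    let slots: Vec<i32> = (0..6).map(|t| write_slot(t, window)).collect();
    assert_eq!(slots, vec![0, 1, 2, 3, 0, 1]);
}
```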