Speculative decoding uses a small “draft” model to generate candidate tokens, which are then verified in parallel by the larger “target” model. This can achieve a 2-3x speedup with identical output quality, since only tokens the target model itself would have produced are kept.
How it works
Normal autoregressive generation is inherently sequential:
Step 1: Target → token₁ (200ms)
Step 2: Target → token₂ (200ms)
Step 3: Target → token₃ (200ms)
Total: 600ms for 3 tokens
Speculative decoding parallelizes verification:
Step 1: Draft → [token₁, token₂, token₃] (3 × 20ms = 60ms)
Step 2: Target verifies [token₁, token₂, token₃] in parallel (200ms)
Total: 260ms for 3 tokens → 2.3x speedup
The key insight: verifying multiple tokens in a single forward pass is nearly as fast as generating one token, since the compute is parallelized across the sequence dimension.
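The timing arithmetic above can be sketched as a simple cost model (the latencies are the illustrative values from this section, not measurements):

```rust
/// Cost of generating `n` tokens one at a time with the target model.
fn sequential_cost_ms(n: u32, target_ms: f64) -> f64 {
    n as f64 * target_ms
}

/// Cost of one speculative iteration: k sequential draft forward passes,
/// then a single target pass that verifies all k tokens at once.
fn speculative_cost_ms(k: u32, draft_ms: f64, target_ms: f64) -> f64 {
    k as f64 * draft_ms + target_ms
}

fn main() {
    let (draft_ms, target_ms) = (20.0, 200.0);
    let seq = sequential_cost_ms(3, target_ms); // 600 ms
    let spec = speculative_cost_ms(3, draft_ms, target_ms); // 260 ms
    println!("sequential: {seq} ms, speculative: {spec} ms, speedup: {:.1}x", seq / spec);
}
```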
Architecture
From mlx-rs-core/src/speculative.rs:
/// Speculative decoding generator
///
/// Uses a draft model to speculate multiple tokens ahead, then verifies them
/// with the target model in a single forward pass.
pub struct SpeculativeGenerate<'a, M, D, C>
where
M: SpeculativeModel,
D: SpeculativeModel,
C: KeyValueCache + Default,
{
/// Target (large) model
target_model: &'a mut M,
/// Draft (small) model
draft_model: &'a mut D,
/// KV cache for target model
target_cache: &'a mut Vec<Option<C>>,
/// KV cache for draft model
draft_cache: &'a mut Vec<Option<C>>,
/// Number of draft tokens to generate per step
num_draft_tokens: usize,
/// Temperature for sampling
temperature: f32,
// ...
}
Algorithm
1. Generate draft tokens
mlx-rs-core/src/speculative.rs
/// Generate draft tokens using the draft model
fn generate_draft_tokens(&mut self, start_token: &Array) -> Result<Vec<Array>, Exception> {
let mut tokens = Vec::with_capacity(self.num_draft_tokens);
let mut current = start_token.clone();
for _ in 0..self.num_draft_tokens {
let input = current.index((.., NewAxis));
let logits = self.draft_model.forward_speculative(&input, self.draft_cache)?;
let token = self.sample(&logits)?;
let _ = async_eval([&token]);
tokens.push(token.clone());
current = token;
}
Ok(tokens)
}
2. Verify with target model
mlx-rs-core/src/speculative.rs
/// Verify draft tokens with target model
/// Returns (accepted_count, all_tokens, all_logprobs)
fn verify_draft_tokens(
&mut self,
input_tokens: &Array,
) -> Result<(usize, Vec<Array>, Vec<Array>), Exception> {
// Forward all tokens through target model at once
let logits = self.target_model.forward_speculative(input_tokens, self.target_cache)?;
let seq_len = input_tokens.shape()[1] as usize;
let mut tokens = Vec::with_capacity(seq_len);
let mut logprobs = Vec::with_capacity(seq_len);
// Sample from each position's logits
for i in 0..seq_len {
let pos_logits = logits.index((.., i as i32, ..));
let token = self.sample(&pos_logits)?;
// Compute log probabilities
let log_sum_exp = logsumexp_axis(&pos_logits, -1, true)?;
let lp = pos_logits.subtract(&log_sum_exp)?;
tokens.push(token);
logprobs.push(lp);
}
eval(&tokens)?;
Ok((seq_len, tokens, logprobs))
}
3. Accept or reject
mlx-rs-core/src/speculative.rs
// Compare draft tokens with target tokens to find acceptance
let mut accepted = 0;
for i in 0..draft_tokens.len() {
let draft_id = draft_tokens[i].item::<u32>();
let target_id = target_tokens[i].item::<u32>();
if draft_id == target_id {
accepted += 1;
// Queue accepted token
self.pending_tokens.push(SpeculativeToken {
token: draft_tokens[i].clone(),
logprobs: target_logprobs[i].clone(),
from_draft: true,
});
} else {
break; // Stop at first mismatch
}
}
// The token at position `accepted` is from target model (correction or next)
let final_token = target_tokens[accepted].clone();
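The acceptance rule above reduces to taking the longest matching prefix of the draft and target token sequences. A standalone sketch on plain token IDs (the helper name is illustrative, not the library's API):

```rust
/// Number of draft tokens accepted: length of the longest prefix
/// on which the draft and target models sampled the same token.
fn accepted_prefix_len(draft: &[u32], target: &[u32]) -> usize {
    draft.iter().zip(target).take_while(|(d, t)| d == t).count()
}

fn main() {
    let draft = [5, 9, 2, 7];
    // Target agrees on the first two positions, then diverges.
    // It has one extra position: the correction/bonus token.
    let target = [5, 9, 4, 7, 1];
    let accepted = accepted_prefix_len(&draft, &target);
    assert_eq!(accepted, 2);
    // The token at position `accepted` comes from the target model.
    assert_eq!(target[accepted], 4);
}
```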
Usage example
use mlx_rs_core::{SpeculativeGenerate, KVCache};
// Load models
let mut target_model = load_model("Qwen3-7B")?; // Large, accurate
let mut draft_model = load_model("Qwen3-0.5B")?; // Small, fast
// Initialize caches
let mut target_cache: Vec<Option<KVCache>> = vec![None; target_model.num_layers()];
let mut draft_cache: Vec<Option<KVCache>> = vec![None; draft_model.num_layers()];
// Create speculative generator
let mut generator = SpeculativeGenerate::new(
    &mut target_model,
    &mut draft_model,
    &mut target_cache,
    &mut draft_cache,
    4,             // num_draft_tokens: draft tokens per iteration
    0.7,           // temperature
    &prompt_array, // prompt
);
// Generate tokens
for result in generator {
let spec_token = result?;
let id = spec_token.token.item::<u32>();
if spec_token.from_draft {
// Token was accepted from draft model
print!("✓{}", tokenizer.decode(&[id], true)?);
} else {
// Token was corrected by target model
print!("✗{}", tokenizer.decode(&[id], true)?);
}
}
Result struct
mlx-rs-core/src/speculative.rs
/// Result of speculative decoding step
pub struct SpeculativeToken {
/// The generated token
pub token: Array,
/// Log probabilities
pub logprobs: Array,
/// Whether this token was accepted from draft model (true) or from target model (false)
pub from_draft: bool,
}
The `from_draft` flag indicates acceptance:
- `true`: Draft token matched the target's prediction (accepted)
- `false`: Draft token differed; the target's correction was used
Theoretical speedup depends on:
- Draft model latency per token (d)
- Target model latency per token (t)
- Acceptance rate (α)
- Number of draft tokens per iteration (k)

Each iteration costs k × d + t and yields, on average, 1 + k × α tokens (the accepted drafts plus one token from the target model), so:

speedup = t × (1 + k × α) / (k × d + t)

Example with realistic values:
- Target model: t = 200ms/token
- Draft model: d = 20ms/token (10x faster)
- k = 4 draft tokens
- α = 0.7 acceptance rate

Tokens per iteration: 1 + 4 × 0.7 = 3.8
Normal: 3.8 tokens × 200ms = 760ms
Speculative: 4 × 20ms + 200ms = 280ms
Speedup: 760 / 280 ≈ 2.7x
Acceptance rate
Acceptance rate determines speedup:
| Acceptance Rate | Expected Speedup (k=4) |
|---|---|
| 100% (perfect) | 3.6x |
| 80% | 3.0x |
| 60% | 2.4x |
| 40% | 1.9x |
| 20% | 1.3x |
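The table values follow directly from the speedup formula; a small check using the same illustrative numbers (t = 200ms, d = 20ms, k = 4):

```rust
/// Expected speedup: (tokens per iteration × target latency) / iteration cost.
fn speedup(t_ms: f64, d_ms: f64, k: u32, alpha: f64) -> f64 {
    let tokens_per_iter = 1.0 + k as f64 * alpha;
    tokens_per_iter * t_ms / (k as f64 * d_ms + t_ms)
}

fn main() {
    for alpha in [1.0, 0.8, 0.6, 0.4, 0.2] {
        println!("alpha = {alpha}: {:.1}x", speedup(200.0, 20.0, 4, alpha));
    }
}
```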
Acceptance rate is highest when:
- Draft and target models are from the same family (e.g., both Qwen)
- Task is predictable (factual QA, formatting)
- Temperature is low (greedy or near-greedy decoding)
When to use speculative decoding
Good use cases:
- Interactive applications where latency matters
- High-quality draft model available (same family as target)
- Batch size = 1 (single user)
- Predictable generation (factual, structured output)
- Low temperature (deterministic generation)
Not recommended for:
- Batch size > 1 (draft overhead multiplies)
- Creative generation (low acceptance rate)
- High temperature (random sampling reduces acceptance)
- No suitable small model available
- Draft model is < 5x faster than target
Model selection
Draft model requirements
- Speed: At least 5x faster than target (ideally 10x+)
- Architecture: Same family as target for higher acceptance
- Size: 5-20% of target model parameters
- Vocabulary: Must match target model exactly
Recommended pairs
| Target Model | Draft Model | Speed Ratio | Expected Acceptance |
|---|---|---|---|
| Qwen3-7B | Qwen3-0.5B | 12x | 70-80% |
| Qwen3-14B | Qwen3-1.5B | 10x | 70-75% |
| GLM4-9B | GLM4-1B | 8x | 65-75% |
| Mistral-7B | TinyLlama-1.1B | 6x | 50-60% |
Mismatched model families (e.g., Qwen draft + Llama target) have significantly lower acceptance rates (30-40%), reducing speedup to ~1.5x.
Tuning num_draft_tokens
More draft tokens = more potential parallelism, but also more compute if rejected:
| k (draft tokens) | Best For | Trade-off |
|---|---|---|
| 2 | Conservative, low acceptance rate | Minimal wasted work |
| 4 | Recommended default | Balanced |
| 8 | High acceptance rate (>75%) | More wasted work on rejection |
| 16 | Very high acceptance (>85%) | Risky, large overhead |
// Conservative: safe for creative generation
let generator = SpeculativeGenerate::new(
    // ...
    2, // num_draft_tokens
    // ...
);
// Aggressive: maximize speedup for factual QA
let generator = SpeculativeGenerate::new(
    // ...
    8, // num_draft_tokens
    // ...
);
Implementation details
State machine
mlx-rs-core/src/speculative.rs
enum SpeculativeState<'a> {
/// Initial state - need to process prompt
Prefill { prompt: &'a Array },
/// Main generation loop
Generate { last_token: Array },
}
The iterator starts in Prefill and then remains in Generate:
- Prefill: Process the prompt through both models once
- Generate: Draft → Verify → Accept/Reject loop
Pending tokens
When multiple draft tokens are accepted, they’re queued:
mlx-rs-core/src/speculative.rs
/// Token count
token_count: usize,
/// Pending tokens from verification (when draft tokens are accepted)
pending_tokens: Vec<SpeculativeToken>,
The iterator yields one token at a time, draining the queue before generating new drafts.
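The drain-before-draft behavior can be sketched independently of the MLX types (a `VecDeque` stands in for `pending_tokens`; all names here are illustrative):

```rust
use std::collections::VecDeque;

/// Minimal stand-in for the iterator's queueing behavior: yield queued
/// tokens first, and only run a new draft/verify step once the queue is empty.
struct Pending {
    queue: VecDeque<u32>,
}

impl Pending {
    fn next_token(&mut self, run_step: impl FnOnce() -> Vec<u32>) -> Option<u32> {
        if self.queue.is_empty() {
            // One speculative step may accept several tokens at once.
            self.queue.extend(run_step());
        }
        self.queue.pop_front()
    }
}

fn main() {
    let mut p = Pending { queue: VecDeque::new() };
    // First call triggers a step that accepts three tokens...
    assert_eq!(p.next_token(|| vec![10, 11, 12]), Some(10));
    // ...subsequent calls drain the queue without running another step.
    assert_eq!(p.next_token(|| unreachable!()), Some(11));
    assert_eq!(p.next_token(|| unreachable!()), Some(12));
}
```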
Cache synchronization
Both target and draft models maintain separate KV caches:
- Draft cache: Stores speculative keys/values (may be discarded)
- Target cache: Stores verified keys/values (permanent)
On rejection, draft cache must be trimmed to match acceptance point.
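The trimming step can be illustrated with a toy cache (a real KV cache trims along the sequence axis of its key/value arrays; this sketch uses a plain `Vec` of positions and is not the library's cache type):

```rust
/// Toy per-layer cache: one entry per cached sequence position.
struct ToyCache {
    positions: Vec<u32>,
}

impl ToyCache {
    /// Roll back speculative entries so the cache ends at the last verified position.
    fn trim_to(&mut self, verified_len: usize) {
        self.positions.truncate(verified_len);
    }
}

fn main() {
    // The draft cache holds 4 speculative positions, but only 2 drafts were accepted.
    let mut cache = ToyCache { positions: vec![100, 101, 102, 103] };
    let verified_len = 2; // prefix length accepted by the target model
    cache.trim_to(verified_len);
    assert_eq!(cache.positions, vec![100, 101]);
}
```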
Debugging
Monitor acceptance rate during generation:
let mut total_tokens = 0;
let mut accepted_tokens = 0;
for result in generator {
let spec_token = result?;
total_tokens += 1;
if spec_token.from_draft {
accepted_tokens += 1;
}
if total_tokens % 100 == 0 {
let acceptance_rate = accepted_tokens as f32 / total_tokens as f32;
eprintln!("Acceptance rate: {:.1}%", acceptance_rate * 100.0);
}
}
If acceptance rate is < 40%, speculative decoding may be slower than normal generation. Consider using a better draft model or disabling speculation.
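A simple guard based on this rule of thumb (the 40% cutoff and sample count are heuristics from this section, not measured constants):

```rust
/// Running acceptance statistics for deciding whether speculation pays off.
struct AcceptanceMonitor {
    total: usize,
    accepted: usize,
}

impl AcceptanceMonitor {
    fn record(&mut self, from_draft: bool) {
        self.total += 1;
        if from_draft {
            self.accepted += 1;
        }
    }

    fn rate(&self) -> f32 {
        if self.total == 0 { 0.0 } else { self.accepted as f32 / self.total as f32 }
    }

    /// Suggest disabling speculation once enough samples show a low rate.
    fn should_disable(&self, min_samples: usize, threshold: f32) -> bool {
        self.total >= min_samples && self.rate() < threshold
    }
}

fn main() {
    let mut m = AcceptanceMonitor { total: 0, accepted: 0 };
    for i in 0..100 {
        m.record(i % 3 == 0); // roughly one in three tokens accepted
    }
    assert!(m.should_disable(100, 0.40));
}
```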
Limitations
- Batch size: Current implementation supports batch_size=1 only
- Memory: Requires loading two models simultaneously
- Complexity: More code paths, harder to debug
- Temperature sensitivity: High temperature reduces effectiveness
Future enhancements
- Multi-draft verification (verify multiple draft sequences)
- Adaptive k (adjust num_draft_tokens based on acceptance rate)
- Tree-based speculation (generate multiple branches)
- Batch support (speculate for multiple inputs)
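Of these, adaptive k is the most straightforward to sketch: grow k while drafts keep being fully accepted, shrink it toward the acceptance point on rejection. This is a heuristic sketch, not part of the current implementation:

```rust
/// Adjust the number of draft tokens based on the last step's acceptance.
fn adapt_k(k: usize, accepted: usize, min_k: usize, max_k: usize) -> usize {
    if accepted == k {
        // Every draft was accepted: speculate further next time.
        (k + 1).min(max_k)
    } else {
        // A rejection occurred: back off toward the acceptance point.
        accepted.max(min_k)
    }
}

fn main() {
    assert_eq!(adapt_k(4, 4, 2, 8), 5); // full acceptance grows k
    assert_eq!(adapt_k(4, 1, 2, 8), 2); // early rejection shrinks k
    assert_eq!(adapt_k(8, 8, 2, 8), 8); // growth is capped at max_k
}
```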
References