Speculative decoding uses a small “draft” model to generate candidate tokens, which are then verified in parallel by the larger “target” model. This can achieve 2-3x speedup with identical output quality.

How it works

Normal autoregressive generation is inherently sequential:
Step 1: Target → token₁ (200ms)
Step 2: Target → token₂ (200ms)  
Step 3: Target → token₃ (200ms)
Total: 600ms for 3 tokens
Speculative decoding parallelizes verification:
Step 1: Draft → [token₁, token₂, token₃] (3 × 20ms = 60ms)
Step 2: Target verifies [token₁, token₂, token₃] in parallel (200ms)
Total: 260ms for 3 tokens → 2.3x speedup
The key insight: verifying multiple tokens in a single forward pass is nearly as fast as generating one token, since the compute is parallelized across the sequence dimension.
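The timing arithmetic above is easy to verify. A minimal sketch, using the example's assumed latencies (200ms/token target, 20ms/token draft, which are illustrative figures, not measurements):

```rust
/// Sequential baseline (k target passes) divided by one speculative
/// step's wall time (k draft passes plus one parallel verify pass).
fn draft_then_verify_speedup(k: f64, draft_ms: f64, target_ms: f64) -> f64 {
    let sequential = k * target_ms;             // k target forward passes
    let speculative = k * draft_ms + target_ms; // k draft passes + 1 verify
    sequential / speculative
}

fn main() {
    // 3 tokens at 200ms/token (target) vs 20ms/token (draft)
    let s = draft_then_verify_speedup(3.0, 20.0, 200.0);
    println!("speedup: {s:.1}x"); // 600ms / 260ms ≈ 2.3x
}
```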

Architecture

mlx-rs-core/src/speculative.rs
/// Speculative decoding generator
///
/// Uses a draft model to speculate multiple tokens ahead, then verifies them
/// with the target model in a single forward pass.
pub struct SpeculativeGenerate<'a, M, D, C>
where
    M: SpeculativeModel,
    D: SpeculativeModel,
    C: KeyValueCache + Default,
{
    /// Target (large) model
    target_model: &'a mut M,
    /// Draft (small) model
    draft_model: &'a mut D,
    /// KV cache for target model
    target_cache: &'a mut Vec<Option<C>>,
    /// KV cache for draft model
    draft_cache: &'a mut Vec<Option<C>>,
    /// Number of draft tokens to generate per step
    num_draft_tokens: usize,
    /// Temperature for sampling
    temperature: f32,
    // ...
}

Algorithm

1. Generate draft tokens

mlx-rs-core/src/speculative.rs
/// Generate draft tokens using the draft model
fn generate_draft_tokens(&mut self, start_token: &Array) -> Result<Vec<Array>, Exception> {
    let mut tokens = Vec::with_capacity(self.num_draft_tokens);
    let mut current = start_token.clone();

    for _ in 0..self.num_draft_tokens {
        let input = current.index((.., NewAxis));
        let logits = self.draft_model.forward_speculative(&input, self.draft_cache)?;
        let token = self.sample(&logits)?;
        let _ = async_eval([&token]);
        tokens.push(token.clone());
        current = token;
    }

    Ok(tokens)
}

2. Verify with target model

mlx-rs-core/src/speculative.rs
/// Verify draft tokens with target model
/// Returns (accepted_count, all_tokens, all_logprobs)
fn verify_draft_tokens(
    &mut self,
    input_tokens: &Array,
) -> Result<(usize, Vec<Array>, Vec<Array>), Exception> {
    // Forward all tokens through target model at once
    let logits = self.target_model.forward_speculative(input_tokens, self.target_cache)?;

    let seq_len = input_tokens.shape()[1] as usize;
    let mut tokens = Vec::with_capacity(seq_len);
    let mut logprobs = Vec::with_capacity(seq_len);

    // Sample from each position's logits
    for i in 0..seq_len {
        let pos_logits = logits.index((.., i as i32, ..));
        let token = self.sample(&pos_logits)?;

        // Compute log probabilities
        let log_sum_exp = logsumexp_axis(&pos_logits, -1, true)?;
        let lp = pos_logits.subtract(&log_sum_exp)?;

        tokens.push(token);
        logprobs.push(lp);
    }

    eval(&tokens)?;
    Ok((seq_len, tokens, logprobs))
}

3. Accept or reject

mlx-rs-core/src/speculative.rs
// Compare draft tokens with target tokens to find acceptance
let mut accepted = 0;
for i in 0..draft_tokens.len() {
    let draft_id = draft_tokens[i].item::<u32>();
    let target_id = target_tokens[i].item::<u32>();

    if draft_id == target_id {
        accepted += 1;
        // Queue accepted token
        self.pending_tokens.push(SpeculativeToken {
            token: draft_tokens[i].clone(),
            logprobs: target_logprobs[i].clone(),
            from_draft: true,
        });
    } else {
        break;  // Stop at first mismatch
    }
}

// The token at position `accepted` is from target model (correction or next)
let final_token = target_tokens[accepted].clone();
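The acceptance rule operates purely on token ids. A self-contained sketch of the same logic over plain `u32` ids (instead of `Array`s, an assumption for illustration):

```rust
/// Count how many draft tokens form an exact prefix match with the
/// target model's tokens; acceptance stops at the first mismatch.
fn accepted_prefix(draft: &[u32], target: &[u32]) -> usize {
    draft.iter()
        .zip(target)
        .take_while(|(d, t)| d == t)
        .count()
}

fn main() {
    let draft = [17, 42, 99, 7];
    let target = [17, 42, 3, 7]; // target disagrees at position 2
    let accepted = accepted_prefix(&draft, &target);
    assert_eq!(accepted, 2);
    // The final token for this step is target[accepted]: the target's
    // correction at the mismatch (or its next token if all matched).
    assert_eq!(target[accepted], 3);
}
```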

Usage example

use mlx_rs_core::{SpeculativeGenerate, KVCache};

// Load models
let mut target_model = load_model("Qwen3-7B")?;   // Large, accurate
let mut draft_model = load_model("Qwen3-0.5B")?;  // Small, fast

// Initialize caches
let mut target_cache: Vec<Option<KVCache>> = vec![None; target_model.num_layers()];
let mut draft_cache: Vec<Option<KVCache>> = vec![None; draft_model.num_layers()];

// Create speculative generator
let mut generator = SpeculativeGenerate::new(
    &mut target_model,
    &mut draft_model,
    &mut target_cache,
    &mut draft_cache,
    4,              // num_draft_tokens: draft tokens per iteration
    0.7,            // temperature
    &prompt_array,  // prompt
);

// Generate tokens
for result in generator {
    let spec_token = result?;
    let id = spec_token.token.item::<u32>();
    
    if spec_token.from_draft {
        // Token was accepted from draft model
        print!("✓{}", tokenizer.decode(&[id], true)?);
    } else {
        // Token was corrected by target model
        print!("✗{}", tokenizer.decode(&[id], true)?);
    }
}

Result struct

mlx-rs-core/src/speculative.rs
/// Result of speculative decoding step
pub struct SpeculativeToken {
    /// The generated token
    pub token: Array,
    /// Log probabilities
    pub logprobs: Array,
    /// Whether this token was accepted from draft model (true) or from target model (false)
    pub from_draft: bool,
}
The from_draft flag indicates acceptance:
  • true: Draft token matched target prediction (accepted)
  • false: Draft token differed, using target correction

Performance considerations

Speedup formula

Theoretical speedup depends on:
  • Draft model speed (d)
  • Target model speed (t)
  • Acceptance rate (α)
  • Number of draft tokens (k)
speedup = (t × (1 + k × α)) / (k × d + t)
The numerator is the baseline cost of the 1 + k × α tokens a step produces on average; the denominator is the step's wall time.
Example with realistic values:
  • Target model: t = 200ms/token
  • Draft model: d = 20ms/token (10x faster)
  • k = 4 draft tokens
  • α = 0.7 acceptance rate
Step wall time: 4 × 20ms + 200ms = 280ms
Expected tokens per step: 1 + 4 × 0.7 = 3.8
Speedup: (200ms × 3.8) / 280ms ≈ 2.7x
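The formula can be checked numerically. A minimal sketch using the example's assumed latencies:

```rust
/// Theoretical speedup: expected tokens per step (1 + k·α) times the
/// target's per-token latency, divided by the step's wall time (k·d + t).
fn speedup(k: f64, alpha: f64, draft_ms: f64, target_ms: f64) -> f64 {
    target_ms * (1.0 + k * alpha) / (k * draft_ms + target_ms)
}

fn main() {
    // Assumed values from the example: t = 200ms, d = 20ms, k = 4, α = 0.7
    let s = speedup(4.0, 0.7, 20.0, 200.0);
    println!("expected speedup: {s:.2}x"); // ≈ 2.71x
}
```

With perfect acceptance (α = 1.0) the same parameters give ≈ 3.57x, matching the 100% row in the table below.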

Acceptance rate

Acceptance rate determines speedup:
Acceptance Rate | Expected Speedup (k=4)
100% (perfect)  | 3.5x
80%             | 3.0x
60%             | 2.4x
40%             | 1.8x
20%             | 1.3x
Acceptance rate is highest when:
  • Draft and target models are from the same family (e.g., both Qwen)
  • Task is predictable (factual QA, formatting)
  • Temperature is low (greedy or near-greedy decoding)

When to use speculative decoding

Good use cases:
  • Interactive applications where latency matters
  • High-quality draft model available (same family as target)
  • Batch size = 1 (single user)
  • Predictable generation (factual, structured output)
  • Low temperature (deterministic generation)
Not recommended for:
  • Batch size > 1 (draft overhead multiplies)
  • Creative generation (low acceptance rate)
  • High temperature (random sampling reduces acceptance)
  • No suitable small model available
  • Draft model is < 5x faster than target

Model selection

Draft model requirements

  1. Speed: At least 5x faster than target (ideally 10x+)
  2. Architecture: Same family as target for higher acceptance
  3. Size: 5-20% of target model parameters
  4. Vocabulary: Must match target model exactly
Target Model | Draft Model    | Speed Ratio | Expected Acceptance
Qwen3-7B     | Qwen3-0.5B     | 12x         | 70-80%
Qwen3-14B    | Qwen3-1.5B     | 10x         | 70-75%
GLM4-9B      | GLM4-1B        | 8x          | 65-75%
Mistral-7B   | TinyLlama-1.1B | 6x          | 50-60%
Mismatched model families (e.g., Qwen draft + Llama target) have significantly lower acceptance rates (30-40%), reducing speedup to ~1.5x.

Tuning num_draft_tokens

More draft tokens = more potential parallelism, but also more compute if rejected:
k (draft tokens) | Best For                          | Trade-off
2                | Conservative, low acceptance rate | Minimal wasted work
4                | Recommended default               | Balanced
8                | High acceptance rate (>75%)       | More wasted work on rejection
16               | Very high acceptance (>85%)       | Risky, large overhead
// Conservative: safe for creative generation
let generator = SpeculativeGenerate::new(
    /* ... */
    2,  // num_draft_tokens
    /* ... */
);

// Aggressive: maximize speedup for factual QA
let generator = SpeculativeGenerate::new(
    /* ... */
    8,  // num_draft_tokens
    /* ... */
);

Implementation details

State machine

mlx-rs-core/src/speculative.rs
enum SpeculativeState<'a> {
    /// Initial state - need to process prompt
    Prefill { prompt: &'a Array },
    /// Main generation loop
    Generate { last_token: Array },
}
The iterator alternates between:
  1. Prefill: Process prompt through both models once
  2. Generate: Draft → Verify → Accept/Reject loop
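The two-state alternation can be sketched independently of the models (the state payloads and the work done per state are stand-ins for the real prefill and draft/verify logic):

```rust
enum State {
    Prefill { prompt: Vec<u32> },
    Generate { last_token: u32 },
}

/// Advance the state machine one step, returning the next state and
/// any tokens emitted by this step.
fn step(state: State) -> (State, Vec<u32>) {
    match state {
        // Prefill runs once: process the whole prompt, emit nothing,
        // and hand the last prompt token to the generate loop.
        State::Prefill { prompt } => {
            let last = *prompt.last().expect("non-empty prompt");
            (State::Generate { last_token: last }, vec![])
        }
        // Generate: draft → verify → accept; this stand-in just emits
        // one incremented token per step.
        State::Generate { last_token } => {
            let next = last_token + 1; // placeholder for draft/verify
            (State::Generate { last_token: next }, vec![next])
        }
    }
}

fn main() {
    let mut state = State::Prefill { prompt: vec![10, 11, 12] };
    let mut out = Vec::new();
    for _ in 0..4 {
        let (next_state, emitted) = step(state);
        out.extend(emitted);
        state = next_state;
    }
    assert_eq!(out, [13, 14, 15]); // first step prefilled, then 3 tokens
}
```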

Pending tokens

When multiple draft tokens are accepted, they’re queued:
mlx-rs-core/src/speculative.rs
/// Token count
token_count: usize,
/// Pending tokens from verification (when draft tokens are accepted)
pending_tokens: Vec<SpeculativeToken>,
The iterator yields one token at a time, draining the queue before generating new drafts.
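The drain-then-refill pattern can be sketched on its own (hypothetical stand-in types; the real iterator runs a draft/verify step where `run_speculative_step` returns canned tokens here):

```rust
use std::collections::VecDeque;

/// Minimal sketch of the drain-then-refill iterator: yield queued
/// tokens one at a time, and only run a new draft/verify step when
/// the queue is empty.
struct TokenStream {
    pending: VecDeque<u32>,
    steps_run: usize,
}

impl TokenStream {
    /// Stand-in for one draft → verify → accept step.
    fn run_speculative_step(&mut self) -> Vec<u32> {
        self.steps_run += 1;
        vec![1, 2, 3] // pretend three tokens were accepted
    }
}

impl Iterator for TokenStream {
    type Item = u32;
    fn next(&mut self) -> Option<u32> {
        if self.pending.is_empty() {
            let accepted = self.run_speculative_step();
            self.pending.extend(accepted);
        }
        self.pending.pop_front()
    }
}

fn main() {
    let mut stream = TokenStream { pending: VecDeque::new(), steps_run: 0 };
    let first_six: Vec<u32> = stream.by_ref().take(6).collect();
    assert_eq!(first_six, [1, 2, 3, 1, 2, 3]);
    assert_eq!(stream.steps_run, 2); // six tokens needed only two steps
}
```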

Cache synchronization

Both target and draft models maintain separate KV caches:
  • Draft cache: Stores speculative keys/values (may be discarded)
  • Target cache: Stores verified keys/values (permanent)
On rejection, draft cache must be trimmed to match acceptance point.
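The bookkeeping can be illustrated with a toy cache that tracks only its length (the real cache stores key/value arrays; `ToyCache` and its methods are hypothetical):

```rust
/// Toy KV cache tracking only its length, to illustrate the
/// synchronization rule on rejection.
struct ToyCache {
    len: usize,
}

impl ToyCache {
    /// Cache `n` new positions.
    fn append(&mut self, n: usize) {
        self.len += n;
    }
    /// Drop the most recent `n` positions (rejected speculation).
    fn trim(&mut self, n: usize) {
        self.len -= n.min(self.len);
    }
}

fn main() {
    let k = 4;        // draft tokens speculated this step
    let accepted = 2; // tokens the target model agreed with

    let mut draft_cache = ToyCache { len: 100 };
    draft_cache.append(k);            // cached all speculative positions
    draft_cache.trim(k - accepted);   // roll back the rejected tail
    assert_eq!(draft_cache.len, 102); // only accepted tokens remain
}
```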

Debugging

Monitor acceptance rate during generation:
let mut total_tokens = 0;
let mut accepted_tokens = 0;

for result in generator {
    let spec_token = result?;
    total_tokens += 1;
    
    if spec_token.from_draft {
        accepted_tokens += 1;
    }
    
    if total_tokens % 100 == 0 {
        let acceptance_rate = accepted_tokens as f32 / total_tokens as f32;
        eprintln!("Acceptance rate: {:.1}%", acceptance_rate * 100.0);
    }
}
If acceptance rate is < 40%, speculative decoding may be slower than normal generation. Consider using a better draft model or disabling speculation.

Limitations

  1. Batch size: Current implementation supports batch_size=1 only
  2. Memory: Requires loading two models simultaneously
  3. Complexity: More code paths, harder to debug
  4. Temperature sensitivity: High temperature reduces effectiveness

Future enhancements

  • Multi-draft verification (verify multiple draft sequences)
  • Adaptive k (adjust num_draft_tokens based on acceptance rate)
  • Tree-based speculation (generate multiple branches)
  • Batch support (speculate for multiple inputs)
