
What is KV Cache?

Key-Value (KV) cache is a critical optimization technique for autoregressive transformer models. During text generation, each token attends to all previous tokens in the sequence. Without caching, this would require recomputing attention keys and values for all previous tokens at each step. The KV cache stores these computed attention keys and values, allowing the model to:
  • Compute attention only for the new token(s)
  • Reuse cached keys and values from previous tokens
  • Dramatically reduce computational cost from O(n²) to O(n) per token

Performance Impact

Without KV cache:
  • First token: 100ms
  • Second token: 200ms (recomputes first token)
  • Third token: 300ms (recomputes first two tokens)
  • Total for 100 tokens: ~500 seconds
With KV cache:
  • First token: 100ms (fills cache)
  • Each subsequent token: ~100ms (uses cache)
  • Total for 100 tokens: ~10 seconds (50x faster!)
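
These totals follow from simple arithmetic: without caching, step n reprocesses all n tokens, so per-step latency grows linearly and the total grows quadratically. A standalone sketch of the calculation:

```python
# Per-token processing cost assumed in the example above.
ms_per_token = 100

# Without cache: step n recomputes keys/values for all n tokens so far.
total_without = sum(n * ms_per_token for n in range(1, 101))  # milliseconds
# With cache: each step processes only the new token.
total_with = 100 * ms_per_token  # milliseconds

print(total_without / 1000)        # 505.0 seconds (~500 s)
print(total_with / 1000)           # 10.0 seconds
print(total_without / total_with)  # 50.5 (~50x speedup)
```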

Why KV Cache Matters

For a transformer with:
  • L layers
  • H attention heads
  • D head dimension
  • S sequence length
  • B batch size
KV cache memory usage:
Memory = 2 × L × B × H × S × D × sizeof(dtype)
Example for Llama-2-7B (L=32, H=32, D=128) with batch=1, sequence=2048, fp16:
2 × 32 × 1 × 32 × 2048 × 128 × 2 bytes = 1 GB per sequence
This is why KV cache management is crucial for:
  • Memory efficiency
  • Batch size scaling
  • Long context support
  • Multi-sequence generation
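
The formula is easy to evaluate directly. A minimal sketch (`kv_cache_bytes` is a hypothetical helper for illustration, not part of the GenAI API):

```python
def kv_cache_bytes(layers, batch, heads, seq_len, head_dim, dtype_bytes=2):
    """Memory = 2 (K and V) x L x B x H x S x D x sizeof(dtype)."""
    return 2 * layers * batch * heads * seq_len * head_dim * dtype_bytes

# Llama-2-7B: L=32, H=32, D=128, batch=1, sequence=2048, fp16 (2 bytes)
mem = kv_cache_bytes(layers=32, batch=1, heads=32, seq_len=2048, head_dim=128)
print(mem / 2**30)  # 1.0 GiB per sequence
```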

KV Cache Implementation

ONNX Runtime GenAI provides several KV cache implementations (from src/models/kv_cache.h):

Cache Types

DefaultKeyValueCache

Standard implementation where past and present are separate tensors. Usage:
  • Most model architectures
  • Default choice for CPU and most execution providers
Implementation (from src/models/kv_cache.h:64):
struct DefaultKeyValueCache : KeyValueCache {
  void Add() override;              // Add cache I/O to state
  void Update(DeviceSpan<int32_t> beam_indices, 
              int total_length) override;  // Move present→past
  void RewindTo(size_t index) override;    // For continuous decoding
  
  std::vector<std::unique_ptr<OrtValue>> pasts_, presents_;
  std::array<int64_t, 4> shape_;  // [batch, heads, seq, dim]
};
CombinedKeyValueCache

Keys and values are combined in a single tensor per layer. Usage:
  • Models with past_names config (combined KV format)
  • Some optimized model exports
Shape: [batch, 2, heads, seq, dim], where dimension 1 holds [keys, values]. Implementation (from src/models/kv_cache.h:31):
struct CombinedKeyValueCache : KeyValueCache {
  std::array<int64_t, 5> shape_;  // Extra dimension for K/V
  std::vector<std::unique_ptr<OrtValue>> pasts_, presents_;
};
ModelManagedKeyValueCache

KV cache is managed internally by the model or execution provider. Usage:
  • QNN execution provider
  • Models with stateful sessions
  • Custom execution providers with internal caching
Implementation (from src/models/kv_cache.h:130):
struct ModelManagedKeyValueCache : KeyValueCache {
  void Add() override { /* no-op */ }
  void Update(DeviceSpan<int32_t> beam_indices, int total_length) override;
  void RewindTo(size_t index) override { /* not supported */ }
};

Cache Shape Formats

From the configuration and implementation:

Default (Separate K/V):
Shape: [batch_size, num_heads, sequence_length, head_dim]
Inputs:  past_key_values.0.key, past_key_values.0.value, ...
Outputs: present.0.key, present.0.value, ...
Combined (Single Tensor):
Shape: [batch_size, 2, num_heads, sequence_length, head_dim]
Inputs:  past_0, past_1, ...
Outputs: present_0, present_1, ...
Grouped Query Attention (GQA):
Shape: [batch_size, num_kv_heads, sequence_length, head_dim]
# where num_kv_heads < num_attention_heads
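
Because the cache scales with the number of KV heads, GQA shrinks it by the ratio of attention heads to KV heads. A sketch with illustrative dimensions (the 32 query heads / 8 KV heads split is an assumption, not taken from the source):

```python
def kv_cache_bytes(layers, batch, kv_heads, seq_len, head_dim, dtype_bytes=2):
    """KV cache size depends on the number of KV heads, not query heads."""
    return 2 * layers * batch * kv_heads * seq_len * head_dim * dtype_bytes

mha = kv_cache_bytes(32, 1, 32, 2048, 128)  # multi-head: 32 KV heads
gqa = kv_cache_bytes(32, 1, 8, 2048, 128)   # grouped-query: 8 KV heads
print(mha // gqa)  # 4x smaller cache with GQA
```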

Cache Management Strategies

Update Strategy

After each generation step (from src/models/kv_cache.cpp):
void DefaultKeyValueCache::Update(DeviceSpan<int32_t> beam_indices, 
                                  int total_length) {
  if (is_first_update_) {
    // First update: presents are already filled by model
    is_first_update_ = false;
  } else {
    // Subsequent updates: copy present → past
    for (int i = 0; i < layer_count_; i++) {
      if (beam_indices.empty()) {
        // No beam search: simple swap
        std::swap(pasts_[i], presents_[i]);
      } else {
        // Beam search: reorder based on beam_indices
        PickPastState(beam_indices, i);
      }
    }
  }
  
  // Update shape for next iteration
  shape_[2] = total_length;  // sequence_length dimension
}

Beam Search Reordering

Beam search requires reordering cache entries when beams are selected:
template <typename T>
void DefaultKeyValueCache::PickPastState(DeviceSpan<int32_t> beam_indices, 
                                         int index) {
  // beam_indices: [batch_beam_size] mapping new→old beam positions
  
  // Allocate temporary buffer
  auto new_past = OrtValue::CreateTensor(
    Allocator(), shape_, type_);
  
  // Reorder: new_past[i] = old_past[beam_indices[i]]
  for (int i = 0; i < batch_beam_size; i++) {
    int old_index = beam_indices[i];
    CopyCache(new_past, i, *pasts_[index], old_index);
  }
  
  pasts_[index] = std::move(new_past);
}
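
The reordering above is a gather along the batch-beam axis. A NumPy sketch of the same operation (`pick_past_state` here is an illustrative stand-in for the C++ template, not the library's implementation):

```python
import numpy as np

def pick_past_state(past, beam_indices):
    """Reorder cached K/V so that new_past[i] = past[beam_indices[i]].

    past: [batch_beam, heads, seq, dim]
    beam_indices: mapping from new to old beam positions
    """
    return past[beam_indices]  # fancy indexing gathers along axis 0

# Two beams with tiny dimensions for illustration
past = np.arange(2 * 1 * 3 * 1, dtype=np.float16).reshape(2, 1, 3, 1)
# Old beam 1 won both slots: both new beams inherit its cache
new_past = pick_past_state(past, np.array([1, 1]))
print(new_past.shape)  # (2, 1, 3, 1)
```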

Shared Buffer Optimization

For CUDA with greedy search, enable buffer sharing to reduce memory allocations (from src/config.h:308):
{
  "search": {
    "past_present_share_buffer": true
  }
}
Requirements (from src/generators.h:96):
  • CUDA execution provider
  • num_beams=1 (greedy search) OR Whisper model
  • Allocates cache to max_length upfront
Benefits:
  • Eliminates memory allocations during generation
  • Enables CUDA graph capture
  • Reduces latency per token
Implementation:
bool GeneratorParams::IsPastPresentShareBufferEnabled(const std::string& model_type) const {
  return search.past_present_share_buffer && 
         (search.num_beams == 1 || model_type == "whisper");
}

Memory Optimization

Per-Layer Cache Shapes

Models with alternating attention patterns (e.g., sliding window) can have different cache shapes per layer:
// From src/models/kv_cache.h:101
std::vector<std::array<int64_t, 4>> layer_shapes_;  // Per-layer shapes
Example configuration:
{
  "decoder": {
    "sliding_window": {
      "window_size": 4096,
      "layers": [0, 2, 4, 6]  // Only these layers use sliding window
    }
  }
}

Cache Pruning for Long Contexts

For sequences exceeding context length, older cache entries can be pruned:
void RewindTo(size_t new_length) {
  // Truncate cache to new_length
  for (auto& past : pasts_) {
    TruncateTensor(past, new_length);
  }
  shape_[2] = new_length;  // Update sequence dimension
}
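
The same truncation can be sketched with NumPy slicing (`rewind_to` here is illustrative, not the library's implementation):

```python
import numpy as np

def rewind_to(pasts, new_length):
    """Truncate each layer's cached K/V along the sequence axis (axis 2)."""
    return [past[:, :, :new_length, :].copy() for past in pasts]

# Two layers of cache with shape [batch, heads, seq, dim]
pasts = [np.zeros((1, 32, 20, 128), dtype=np.float16) for _ in range(2)]
pasts = rewind_to(pasts, 15)
print(pasts[0].shape)  # (1, 32, 15, 128)
```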

Sliding Window Attention

From src/config.h:228, models can use sliding window to limit cache size:
{
  "decoder": {
    "sliding_window": {
      "window_size": 4096,
      "pad_value": -65504,  // Value for inactive tokens (fp16 lowest)
      "alignment": "right",  // or "left"
      "slide_key_value_cache": true,
      "slide_inputs": true
    }
  }
}
This limits memory to window_size tokens regardless of total sequence length.
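
The cap is easy to quantify: the cached length is min(sequence_length, window_size). A sizing sketch with Llama-2-7B-like dimensions (`kv_cache_bytes` is a hypothetical helper for illustration):

```python
def kv_cache_bytes(layers, batch, heads, cached_len, head_dim, dtype_bytes=2):
    return 2 * layers * batch * heads * cached_len * head_dim * dtype_bytes

window = 4096
sizes = {}
for seq_len in (2048, 8192, 32768):
    cached = min(seq_len, window)  # sliding window caps the cached length
    sizes[seq_len] = kv_cache_bytes(32, 1, 32, cached, 128) / 2**30
    print(f"{seq_len:6d} tokens -> {sizes[seq_len]:.1f} GiB")
# Memory plateaus at 2.0 GiB once the window is full.
```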

Device-Specific Considerations

CUDA

Advantages:
  • Fast GPU memory access
  • CUDA graph capture with shared buffers
  • Efficient beam search reordering with kernels
Configuration:
{
  "search": {
    "past_present_share_buffer": true
  },
  "decoder": {
    "session_options": {
      "provider_options": [{
        "cuda": {
          "enable_cuda_graph": "1"
        }
      }]
    }
  }
}

DirectML

Considerations:
  • Inputs may be on CPU while cache is on GPU
  • Separate device interfaces: p_device_inputs_ vs p_device_kvcache_
// From src/models/model.h:168
DeviceInterface* p_device_inputs_;   // May be CPU
DeviceInterface* p_device_kvcache_;  // Always GPU for DirectML

CPU

Optimization:
  • Use fp16 or quantized models to reduce memory bandwidth
  • Consider smaller batch sizes
  • Cache fits in system RAM (larger than GPU VRAM)

WebGPU

Limitations:
  • Inputs must be on CPU
  • Cache on GPU
  • Cross-device copies required

Continuous Decoding and Rewinding

The KV cache supports rewinding to enable speculative decoding and alternative paths:
model = og.Model('model_path')
params = og.GeneratorParams(model)
generator = og.Generator(model, params)

# Generate initial tokens
generator.append_tokens(input_tokens)
for _ in range(10):
    generator.generate_next_token()

checkpoint_length = generator.token_count()

# Try speculative path
for _ in range(5):
    generator.generate_next_token()

# Rewind and try different path
generator.rewind_to(checkpoint_length)
for _ in range(5):
    generator.generate_next_token()
Implementation (from src/models/kv_cache.h:76):
void DefaultKeyValueCache::RewindTo(size_t index) {
  // Truncate past tensors to index
  RewindPastTensorsTo<T>(index);
  
  // Update shape
  shape_[2] = index;
}

Cross-Attention Cache

Encoder-decoder models use a separate CrossCache for encoder outputs (from src/models/kv_cache.h:109):
struct CrossCache {
  CrossCache(State& state, int sequence_length);
  
  void AddOutputs(State& state);  // Add to encoder outputs
  void AddInputs(State& state);   // Add to decoder inputs
  
  std::vector<std::unique_ptr<OrtValue>> values_;  // Fixed, not updated
};
Unlike self-attention cache:
  • Created once during encoding
  • Never updated during decoding
  • Shared across all decoder steps

Performance Implications

Memory vs. Compute Trade-off

With KV Cache:
  • ✅ Much faster generation (50-100x)
  • ❌ Significant memory usage (1-4 GB per sequence)
  • ❌ Limits batch size and context length
Without KV Cache:
  • ❌ Extremely slow (quadratic in sequence length)
  • ✅ Minimal memory overhead
  • ❌ Not practical for generation

Batch Size Impact

Maximum batch size is often limited by KV cache memory:
# For Llama-2-7B on a 24GB GPU:
# ~1 GB of KV cache per sequence at 2048 context
# ~4 GB model weights (int4-quantized; fp16 weights are ~13 GB)
# → Max batch size ≈ 16-20

params.set_search_options(
    max_length=2048,
    batch_size=16  # Limited by KV cache memory
)
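
The comment's arithmetic can be made explicit. A rough sizing sketch (`max_batch_size` and the ~2 GiB activation overhead are assumptions for illustration):

```python
GiB = 2**30

def max_batch_size(gpu_bytes, weight_bytes, per_seq_cache_bytes, overhead_bytes=0):
    """Rough upper bound: memory left after weights and overhead,
    divided among per-sequence KV caches."""
    free = gpu_bytes - weight_bytes - overhead_bytes
    return max(free // per_seq_cache_bytes, 0)

# 24 GiB GPU, ~4 GiB quantized weights, ~1 GiB cache per sequence,
# plus an assumed ~2 GiB for activations and workspace
print(max_batch_size(24 * GiB, 4 * GiB, 1 * GiB, overhead_bytes=2 * GiB))  # 18
```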

Context Length Scaling

KV cache memory scales linearly with context length:
# Context length impact on memory (Llama-2-7B, batch=1, fp16):
#   1K tokens:   500 MB
#   2K tokens:   1.0 GB
#   4K tokens:   2.0 GB
#   8K tokens:   4.0 GB  
#  16K tokens:   8.0 GB
#  32K tokens:  16.0 GB

Debugging KV Cache

Enable logging to track cache operations:
import onnxruntime_genai as og

og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True)

model = og.Model('model_path')
# Cache operations will be logged
Look for log entries like:
Adding KV cache input: past_key_values.0.key shape=[1,32,0,128]
KV cache update: present.0.key shape=[1,32,15,128]

Best Practices

1. Enable Shared Buffers: For CUDA with greedy search, enable past_present_share_buffer for best performance.
2. Monitor Memory: Track KV cache memory usage, especially for large batch sizes or long contexts.
3. Use Appropriate Precision: FP16 halves KV cache memory vs FP32 with minimal quality impact.
4. Consider Sliding Window: For very long contexts, sliding window attention limits cache growth.
5. Profile Beam Search: Beam search adds overhead for cache reordering; use greedy search if beam search isn't needed.

Next Steps

  • Generation: Learn about search strategies and generation parameters
  • Models: Explore model configuration and optimization
  • Performance Tuning: Optimize inference performance
  • API Reference: Browse the complete API documentation
