The KV cache stores previously computed key-value pairs for reuse during generation, avoiding redundant computation. TensorRT-LLM’s KV cache system supports cross-request reuse, offloading to host memory, and prioritized eviction to maximize cache efficiency.

Overview

The KV cache is a pool of blocks, each of which holds KV state for a fixed number of tokens. Key features include:
  • Paged memory management with block-based allocation
  • Cross-request reuse via radix tree search
  • Prioritized eviction with configurable retention policies
  • Offloading to host memory for extended capacity
  • Multiple data types (FP16, BF16, FP8, INT8, NVFP4)
  • Variable attention windows and grouped query attention support

Architecture

Block-Based Memory

The KV cache divides memory into fixed-size blocks:
  • Each block holds KV state for a fixed number of tokens (configurable, must be power of 2)
  • Multiple layers are packed within a single block
  • Blocks are assigned to requests as needed
  • Separate pools for different attention window sizes and head counts
All layers in a pool must have the same number of heads and attention window size. Multiple pools are created automatically to support models with varying configurations.
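As a rough illustration of how block size relates to model shape (this is a sketch, not the engine's internal formula; the Llama-3.1-8B-style numbers are assumptions):

```python
def block_bytes(tokens_per_block, num_layers, num_kv_heads, head_dim, dtype_bytes):
    """Estimate bytes for one KV cache block: keys and values (factor 2)
    for every layer packed into the same block."""
    return 2 * tokens_per_block * num_layers * num_kv_heads * head_dim * dtype_bytes

# Llama-3.1-8B-like shape: 32 layers, 8 KV heads, head_dim 128, FP16 (2 bytes)
per_block = block_bytes(tokens_per_block=32, num_layers=32,
                        num_kv_heads=8, head_dim=128, dtype_bytes=2)
print(per_block)  # 4194304 bytes -> 4 MiB per 32-token block
```

Halving the element width (e.g. FP8) halves this footprint, which is where the memory savings discussed below come from.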

Radix Tree for Reuse

Blocks are stored in a radix search tree as they are filled:
  • New requests search the tree for matching prefixes
  • Matched blocks are reused instead of recalculated
  • Blocks can be shared among multiple requests
  • Saves both memory and computation
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Enable block reuse (default behavior)
kv_cache_config = KvCacheConfig(
    enable_block_reuse=True  # Cross-request reuse enabled
)

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_config=kv_cache_config)
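The block-level prefix matching that reuse relies on can be sketched with a toy function (this is an illustration of the idea, not the engine's radix tree):

```python
def shared_prefix_blocks(cached_tokens, new_tokens, tokens_per_block=32):
    """Count leading blocks of new_tokens whose token IDs exactly match the cache.
    Only completely filled blocks are eligible for reuse in this sketch."""
    matched = 0
    for start in range(0, min(len(cached_tokens), len(new_tokens)), tokens_per_block):
        a = cached_tokens[start:start + tokens_per_block]
        b = new_tokens[start:start + tokens_per_block]
        if a == b and len(a) == tokens_per_block:
            matched += 1
        else:
            break
    return matched

system = list(range(64))        # 64 shared prompt tokens -> 2 full blocks
req1 = system + [100, 101]      # same system prompt, different question
req2 = system + [200, 201]
print(shared_prefix_blocks(req1, req2))  # 2: shared blocks reused, only the tail recomputed
```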

Configuration

Memory Allocation

Control how much GPU memory is allocated to KV cache:
from tensorrt_llm.llmapi import KvCacheConfig

kv_cache_config = KvCacheConfig(
    free_gpu_memory_fraction=0.9,  # Use 90% of free GPU memory (default)
    max_tokens=8192                # Or limit by token count
)
If both free_gpu_memory_fraction and max_tokens are set, the lesser of the two is used.
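The "lesser of the two" rule can be shown numerically; the free-memory and per-token figures below are hypothetical:

```python
# Hypothetical numbers to illustrate the "lesser of the two" rule
free_gpu_bytes = 40 * 1024**3   # 40 GiB free after loading weights
per_token_kv_bytes = 131072     # ~128 KiB of KV state per token (model-dependent)

tokens_from_fraction = int(0.9 * free_gpu_bytes) // per_token_kv_bytes
max_tokens = 8192

effective_capacity = min(tokens_from_fraction, max_tokens)
print(tokens_from_fraction, effective_capacity)  # 294912 8192
```

Here `max_tokens` is the binding limit, so most of the 90% memory budget would go unused; drop `max_tokens` if you want the fraction alone to decide.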

Data Type

Specify the KV cache data type for memory-performance tradeoffs:
from tensorrt_llm.llmapi import KvCacheConfig

# FP8 KV cache (2x memory reduction vs FP16)
kv_cache_config = KvCacheConfig(dtype='fp8')

# NVFP4 KV cache (4x memory reduction vs FP16)
# Requires offline quantization with ModelOpt
kv_cache_config = KvCacheConfig(dtype='nvfp4')

# Auto-detect from model config (default)
kv_cache_config = KvCacheConfig(dtype='auto')
With dtype='auto' on an unquantized model, the cache inherits the model precision (typically FP16/BF16):
  • Default precision
  • Best accuracy
  • Highest memory usage
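The quoted reductions follow directly from element width. A sketch, treating NVFP4 as 0.5 bytes per element and ignoring any scale-factor overhead:

```python
BYTES_PER_ELEMENT = {"fp16": 2.0, "bf16": 2.0, "fp8": 1.0, "int8": 1.0, "nvfp4": 0.5}

def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype):
    """Per-token KV state: keys and values (factor 2) across all layers."""
    return 2 * num_layers * num_kv_heads * head_dim * BYTES_PER_ELEMENT[dtype]

base = kv_bytes_per_token(32, 8, 128, "fp16")
print(base / kv_bytes_per_token(32, 8, 128, "fp8"))    # 2.0x reduction
print(base / kv_bytes_per_token(32, 8, 128, "nvfp4"))  # 4.0x reduction
```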

Host Memory Offloading

Extend effective cache capacity by offloading to CPU memory:
from tensorrt_llm.llmapi import KvCacheConfig

kv_cache_config = KvCacheConfig(
    host_cache_size=1024*1024*1024,  # 1GB host cache
    secondary_offload_min_priority=35  # Don't offload high-priority blocks
)
How offloading works:
  1. When a GPU block is evicted, it’s first copied to host memory
  2. The block remains reusable from host memory
  3. When reused, it’s copied back to GPU memory
  4. Blocks below secondary_offload_min_priority are evicted directly (not offloaded)
Offloading helps when:
  • Sessions are long-running with high reuse potential
  • Prompts or patterns repeat across requests
  • GPU memory is limited but CPU memory is abundant
  • Trading PCIe bandwidth for additional cache capacity is acceptable
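The eviction-and-offload flow described above can be sketched as a toy two-tier cache; the dictionaries, priorities, and block IDs are illustrative, not the engine's implementation:

```python
SECONDARY_OFFLOAD_MIN_PRIORITY = 35

gpu_cache = {}   # block_id -> (kv_data, priority)
host_cache = {}  # block_id -> (kv_data, priority)

def evict_from_gpu(block_id):
    data, priority = gpu_cache.pop(block_id)
    if priority >= SECONDARY_OFFLOAD_MIN_PRIORITY:
        host_cache[block_id] = (data, priority)   # step 1: copy to host
    # else: dropped outright (step 4)

def lookup(block_id):
    if block_id in gpu_cache:
        return gpu_cache[block_id][0]
    if block_id in host_cache:                    # steps 2-3: still reusable,
        gpu_cache[block_id] = host_cache[block_id]  # copied back to GPU
        return gpu_cache[block_id][0]
    return None

gpu_cache["sys-prompt"] = ("kv-data", 80)
gpu_cache["one-off"] = ("kv-data", 20)
evict_from_gpu("sys-prompt")   # offloaded to host
evict_from_gpu("one-off")      # below threshold: dropped
print(lookup("sys-prompt"), lookup("one-off"))  # kv-data None
```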

Block Retention and Eviction

Priority-Based Eviction

Blocks are assigned priority scores (0-100, higher = more important):
  • Lowest priority blocks are evicted first
  • Only leaf blocks (no descendants in radix tree) can be evicted
  • Prioritized LRU (Least Recently Used) within each priority level
from tensorrt_llm.llmapi import KvCacheRetentionConfig, TokenRangeRetentionConfig
from tensorrt_llm.inputs import TextPrompt

# Configure retention policy
retention_config = KvCacheRetentionConfig(
    # High priority for tokens 0-100
    token_range_retention_configs=[
        TokenRangeRetentionConfig(
            token_start=0,
            token_end=100,
            priority=80,
            duration_ms=60000  # 60 seconds
        )
    ],
    # Medium priority for decoded tokens
    decode_retention_priority=50,
    decode_duration_ms=30000  # 30 seconds
)

prompt = TextPrompt(
    prompt="Your text here",
    kv_cache_retention_config=retention_config
)
Priority reverts to the default (35) after duration_ms elapses from when the block first becomes available for reuse.
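The eviction order described above (lowest priority first, least-recently-used within a priority level, leaves only) can be sketched as follows; this is a toy model, not the scheduler's actual data structure:

```python
def pick_eviction_victim(blocks):
    """blocks: list of dicts with 'id', 'priority', 'last_used', 'is_leaf'.
    Only leaf blocks are evictable; among them, lowest priority wins,
    and the least recently used block breaks ties."""
    leaves = [b for b in blocks if b["is_leaf"]]
    return min(leaves, key=lambda b: (b["priority"], b["last_used"]))["id"]

blocks = [
    {"id": "root",   "priority": 35, "last_used": 1, "is_leaf": False},  # protected
    {"id": "sys",    "priority": 90, "last_used": 2, "is_leaf": True},
    {"id": "chat-a", "priority": 35, "last_used": 5, "is_leaf": True},
    {"id": "chat-b", "priority": 35, "last_used": 3, "is_leaf": True},
]
print(pick_eviction_victim(blocks))  # chat-b: same priority as chat-a, but older
```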

Partial Reuse

Partial block matching enables more flexible reuse:
from tensorrt_llm.llmapi import KvCacheConfig

kv_cache_config = KvCacheConfig(
    enable_partial_reuse=True,        # Enable partial matching (default)
    copy_on_partial_reuse=True        # Copy partially matched blocks (default)
)
With copy_on_partial_reuse=True:
  • Creates a new block and copies matched tokens
  • Allows multiple requests to partially reuse the same block
  • Higher memory usage but more flexible
With copy_on_partial_reuse=False:
  • Reuses the block in place (no copy)
  • Only works if no other request is using the block
  • Lower memory usage but less flexible
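The two modes can be contrasted with a toy single-block matcher; the `shared` flag and return values are illustrative assumptions, not the engine's API:

```python
def partial_reuse(block_tokens, request_tokens, copy_on_partial_reuse, shared):
    """Return ('copy' | 'in_place' | 'none', number_of_matched_tokens).
    `shared` means another request is currently using this block."""
    matched = 0
    for cached, new in zip(block_tokens, request_tokens):
        if cached != new:
            break
        matched += 1
    if matched == 0:
        return "none", 0
    if copy_on_partial_reuse:
        return "copy", matched       # new block, matched tokens copied in
    if not shared:
        return "in_place", matched   # in-place reuse only if nobody else uses it
    return "none", 0

print(partial_reuse([1, 2, 3, 4], [1, 2, 9, 9], True,  shared=True))   # ('copy', 2)
print(partial_reuse([1, 2, 3, 4], [1, 2, 9, 9], False, shared=True))   # ('none', 0)
print(partial_reuse([1, 2, 3, 4], [1, 2, 9, 9], False, shared=False))  # ('in_place', 2)
```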

Security Features

KV Cache Salting

Cache salting prevents unauthorized reuse of cached KV states:
from tensorrt_llm.inputs import TextPrompt

# Different users/tenants use different salts
user1_prompt = TextPrompt(
    prompt="What is AI?",
    cache_salt="user_12345"
)

user2_prompt = TextPrompt(
    prompt="What is AI?",
    cache_salt="user_67890"
)
Only requests with matching cache_salt values can share cached KV blocks. This prevents prompt theft attacks where malicious users might try to infer information from other users’ cached states.
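Conceptually, the salt is folded into the key under which blocks are looked up, so identical prompts from different tenants never collide. A sketch, with SHA-256 standing in for whatever hash the engine uses internally:

```python
import hashlib

def cache_key(cache_salt, token_ids):
    """Toy lookup key: hash over the salt followed by the token IDs."""
    h = hashlib.sha256()
    h.update(cache_salt.encode("utf-8"))
    h.update(bytes(token_ids))  # toy encoding: token IDs as raw bytes
    return h.hexdigest()

same_prompt = [15, 42, 7, 99]
# Same prompt, different salts -> different keys, no cross-tenant reuse
print(cache_key("user_12345", same_prompt) == cache_key("user_67890", same_prompt))  # False
# Same prompt, same salt -> same key, reuse allowed
print(cache_key("user_12345", same_prompt) == cache_key("user_12345", same_prompt))  # True
```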

Multimodal UUID Support

For multimodal models, custom UUIDs enable deterministic cache management:
from tensorrt_llm.inputs import TextPrompt

prompt = TextPrompt(
    prompt="Describe these images.",
    multi_modal_data={"image": [image1, image2]},
    multi_modal_uuids={"image": ["image-uuid-001", "image-uuid-002"]}
)
Cache Correctness: When a UUID is provided, the cache key is computed from both the UUID and content using BLAKE3(UUID || Content). This ensures:
  • Different content always produces different cache entries
  • Same content with different UUIDs produces different entries (user isolation)
  • Original UUID is preserved in KV cache events for external tracking
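A sketch of the BLAKE3(UUID || Content) keying described above, with `hashlib.blake2b` standing in for BLAKE3 (which is not in the Python standard library); the image bytes are placeholders:

```python
import hashlib

def multimodal_cache_key(uuid, content_bytes):
    """Toy key = hash(UUID || Content); blake2b stands in for BLAKE3 here."""
    return hashlib.blake2b(uuid.encode("utf-8") + content_bytes).hexdigest()

img = b"\x89PNG...pixels"  # stand-in for real image bytes
k1 = multimodal_cache_key("image-uuid-001", img)
k2 = multimodal_cache_key("image-uuid-002", img)       # same content, different UUID
k3 = multimodal_cache_key("image-uuid-001", b"other")  # same UUID, different content
print(k1 != k2, k1 != k3)  # True True: both axes isolate cache entries
```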

Advanced Features

Attention Window Size

Configure per-layer attention window sizes:
from tensorrt_llm.llmapi import KvCacheConfig

# Full attention (4096) + sliding window (256) pattern
kv_cache_config = KvCacheConfig(
    max_attention_window=[4096, 256]  # Repeats for all layers
)
If the list length is less than the number of layers, the pattern repeats. For example, [4096, 256] means layer 0 has full attention (4096), layer 1 has sliding window (256), layer 2 has full attention, etc.
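The repetition can be expanded explicitly; the 8-layer model here is a hypothetical example:

```python
from itertools import cycle, islice

def expand_windows(pattern, num_layers):
    """Repeat the per-layer attention window pattern across all layers."""
    return list(islice(cycle(pattern), num_layers))

print(expand_windows([4096, 256], 8))
# [4096, 256, 4096, 256, 4096, 256, 4096, 256]
```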

Grouped Query Attention (GQA)

TensorRT-LLM automatically optimizes KV cache for GQA/MQA:
  • MHA (Multi-Head Attention): One group per head
  • MQA (Multi-Query Attention): Single group for all heads
  • GQA (Grouped Query Attention): Intermediate grouping
Blocks only allocate space for the actual number of KV head groups, reducing memory usage.
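The saving is proportional to the ratio of KV head groups to query heads; the 32-query-head shapes below are illustrative:

```python
def kv_cache_ratio(num_query_heads, num_kv_heads):
    """KV cache size relative to an MHA layout with one KV pair per query head."""
    return num_kv_heads / num_query_heads

print(kv_cache_ratio(32, 32))  # MHA: 1.0 (no saving)
print(kv_cache_ratio(32, 8))   # GQA: 0.25 -> 4x smaller KV cache
print(kv_cache_ratio(32, 1))   # MQA: 0.03125 -> 32x smaller
```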

Streaming and Long Context

Support for models with limited attention windows:
from tensorrt_llm.llmapi import KvCacheConfig

# StreamingLLM-style configuration
kv_cache_config = KvCacheConfig(
    max_attention_window=[2048],  # Sliding window size
    # sink_token_length is model-specific and auto-configured
)
With a sliding attention window, the runtime:
  • Treats the KV cache as a circular buffer
  • Stores only the last N tokens (N = attention window size)
  • Overwrites the oldest cache entries with new tokens
  • Reduces memory for very long sequences
With sink tokens (StreamingLLM), it additionally:
  • Keeps the first S “sink tokens” always in cache
  • Applies the sliding window to the remaining tokens
  • Uses positions within the cache (not the original text) for RoPE
  • Enables efficient long-text generation
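The retention rule above can be sketched as "which original token positions stay cached"; window and sink sizes are illustrative:

```python
def cached_positions(num_tokens, window, num_sink_tokens=0):
    """Positions still cached: the first S sink tokens plus the most recent
    `window` tokens (circular-buffer behavior for everything in between)."""
    sinks = list(range(min(num_sink_tokens, num_tokens)))
    recent_start = max(num_tokens - window, num_sink_tokens)
    return sinks + list(range(recent_start, num_tokens))

# 12 tokens generated so far, window of 4, 2 sink tokens
print(cached_positions(12, window=4, num_sink_tokens=2))  # [0, 1, 8, 9, 10, 11]
```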

Cross-Request Reuse Example

Here’s how to maximize KV cache reuse across requests:
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, KvCacheRetentionConfig, TokenRangeRetentionConfig
from tensorrt_llm.inputs import TextPrompt

# Configure aggressive caching
kv_cache_config = KvCacheConfig(
    free_gpu_memory_fraction=0.9,
    dtype='fp8',                    # Reduce memory usage
    enable_block_reuse=True,
    host_cache_size=2*1024**3,      # 2GB host cache for overflow
)

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_config=kv_cache_config)

# High-priority system prompt
system_prompt_retention = KvCacheRetentionConfig(
    token_range_retention_configs=[
        TokenRangeRetentionConfig(
            token_start=0,
            token_end=1000,
            priority=90,        # Very high priority
            duration_ms=None    # Never expires
        )
    ]
)

prompts = [
    TextPrompt(
        prompt="System: You are a helpful assistant.\n\nUser: What is 2+2?",
        kv_cache_retention_config=system_prompt_retention
    ),
    TextPrompt(
        prompt="System: You are a helpful assistant.\n\nUser: What is 3+3?",
        kv_cache_retention_config=system_prompt_retention
    )
]

# Second prompt reuses "System: You are a helpful assistant." from cache
for output in llm.generate(prompts):
    print(output.text)

Best Practices

Memory and performance:
  • Start with free_gpu_memory_fraction=0.9 (default)
  • Use FP8 KV cache on Hopper+ GPUs for 2x memory savings
  • Enable host offloading for multi-turn conversations
  • Monitor cache hit rates and adjust retention policies
Retention priorities:
  • Assign high priority (80-90) to system prompts and common prefixes
  • Use medium priority (50-60) for user-specific context
  • Set low priority (20-30) for temporary or one-time prompts
  • Use duration_ms=None for prompts that should never expire
Security:
  • Always use cache_salt in multi-tenant environments
  • Set different salts per user, session, or security domain
  • Use multimodal UUIDs for deterministic cache management
  • Monitor for unusual cache hit patterns (potential attacks)

Additional Resources

  • KV Cache API Reference: complete API documentation for KvCacheConfig
  • Retention Config Example: advanced retention policy examples
  • Host Offloading Example: complete host offloading implementation
