RadixAttention is SGLang’s core attention mechanism that enables automatic detection and reuse of common token sequence prefixes across different requests. This dramatically improves throughput and reduces memory usage for workloads with shared context.
What is RadixAttention?
RadixAttention combines two key innovations:
- Radix Tree Data Structure: Organizes KV cache as a tree where common prefixes are shared
- Attention Mechanism: Efficiently computes attention over the shared tree structure
This allows multiple requests to share the same cached key-value states for identical token sequences, significantly reducing both memory usage and computation.
The Radix Tree
The RadixCache in python/sglang/srt/mem_cache/radix_cache.py implements a prefix tree (trie) data structure:
```python
class TreeNode:
    def __init__(self, id=None, priority=0):
        self.children = defaultdict(TreeNode)
        self.parent: TreeNode = None
        self.key: RadixKey = None        # Token sequence
        self.value: torch.Tensor = None  # KV cache indices
        self.lock_ref = 0                # Reference count
```
Reference: python/sglang/srt/mem_cache/radix_cache.py:117-146
Example Tree Structure
Consider three requests:
- Request 1: “What is the capital of France?”
- Request 2: “What is the capital of Germany?”
- Request 3: “What is the weather today?”
The radix tree would look like:
```
                root
                  |
           [What is the]
            /         \
    [capital of]   [weather today]
      /      \
[France?]  [Germany?]
```
The shared prefix “What is the” is stored only once in memory and reused by all three requests.
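The sharing can be demonstrated with a minimal token-per-node trie (a deliberate simplification of the radix tree, which compresses runs of tokens into single edges; `ToyNode` and `insert` here are illustrative, not SGLang code):

```python
class ToyNode:
    def __init__(self):
        self.children = {}  # token -> ToyNode


def insert(root, tokens):
    """Walk/extend the trie; return how many tokens were already cached."""
    node, cached = root, 0
    for tok in tokens:
        if tok in node.children:
            cached += 1  # this token's KV state is already stored
        else:
            node.children[tok] = ToyNode()
        node = node.children[tok]
    return cached


root = ToyNode()
requests = [
    "What is the capital of France?".split(),
    "What is the capital of Germany?".split(),
    "What is the weather today?".split(),
]
hits = [insert(root, r) for r in requests]
# First request: nothing cached; second reuses "What is the capital of";
# third reuses only "What is the".
assert hits == [0, 5, 3]
```

Re-submitting the first request would now hit the cache for all of its tokens, which is exactly the effect RadixAttention exploits for repeated system prompts.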
Key Operations
Prefix Matching
When a new request arrives, RadixCache finds the longest cached prefix:
```python
def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
    """Find the longest cached prefix of key in the radix tree.

    Returns:
        MatchResult with device_indices (KV cache locations) and
        last_node (terminal node of matched prefix)
    """
    key = params.key
    value, last_node = self._match_prefix_helper(self.root_node, key)
    if value:
        value = torch.cat(value)
    else:
        value = torch.empty((0,), dtype=torch.int64, device=self.device)
    return MatchResult(
        device_indices=value,
        last_device_node=last_node,
        last_host_node=last_node,
    )
```
Reference: python/sglang/srt/mem_cache/radix_cache.py:371-441
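The longest-prefix walk itself can be sketched standalone with simplified toy nodes (not SGLang's `TreeNode`): descend while the query agrees with each edge's key, collecting the per-edge KV index slices that `torch.cat` would later concatenate.

```python
class Node:
    def __init__(self, key, value):
        self.key = key      # token slice on the edge into this node
        self.value = value  # per-token KV indices for that slice
        self.children = {}  # first token of a child's key -> child


def match_prefix(root, query):
    node, matched, chunks = root, 0, []
    while matched < len(query) and query[matched] in node.children:
        child = node.children[query[matched]]
        # How far does the query agree with this edge?
        k = 0
        while (k < len(child.key) and matched + k < len(query)
               and child.key[k] == query[matched + k]):
            k += 1
        chunks.append(child.value[:k])
        matched += k
        if k < len(child.key):  # diverged mid-edge (the real cache splits here)
            break
        node = child
    return matched, chunks


root = Node([], [])
root.children[10] = Node([10, 11, 12], [0, 1, 2])  # one cached edge

assert match_prefix(root, [10, 11, 99]) == (2, [[0, 1]])    # partial edge match
assert match_prefix(root, [10, 11, 12]) == (3, [[0, 1, 2]])  # full edge match
```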
Prefix matching happens at page granularity when page_size > 1. If page_size=16, prefixes are aligned to 16-token boundaries.
Prefix Insertion
After a request completes, its tokens and KV cache are inserted into the tree:
```python
def insert(self, params: InsertParams) -> InsertResult:
    """Insert a key-value pair into the radix tree.

    Handles:
    - Creating new nodes for novel suffixes
    - Splitting existing nodes when prefix diverges
    - Updating reference counts
    """
    key = params.key
    value = params.value
    priority = params.priority
    prefix_len = self._insert_helper(self.root_node, key, value, priority)
    return InsertResult(prefix_len=prefix_len)
```
Reference: python/sglang/srt/mem_cache/radix_cache.py:443-457
Node Splitting
When a match ends in the middle of a node, the node is split:
```python
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
    """Split a node at split_len position.

    Before: parent -> child[0:N]
    After:  parent -> new_node[0:split_len] -> child[split_len:N]
    """
    new_node = TreeNode(priority=child.priority)
    new_node.children = {self.get_child_key_fn(key[split_len:]): child}
    new_node.parent = child.parent
    new_node.lock_ref = child.lock_ref
    new_node.key = child.key[:split_len]
    new_node.value = child.value[:split_len].clone()
    child.parent = new_node
    child.key = child.key[split_len:]
    child.value = child.value[split_len:].clone()
    new_node.parent.children[self.get_child_key_fn(key)] = new_node
    return new_node
```
Reference: python/sglang/srt/mem_cache/radix_cache.py:687-706
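The effect of a split is easier to see on a toy compressed trie (simplified types, no `get_child_key_fn`, plain lists instead of tensors):

```python
class Node:
    def __init__(self, key, value, parent=None):
        self.key = key      # token slice held on the edge into this node
        self.value = value  # per-token KV cache indices for that slice
        self.parent = parent
        self.children = {}  # first token of a child's key -> child


def split_node(child, split_len):
    # New node takes over the first split_len tokens of the edge.
    new_node = Node(child.key[:split_len], child.value[:split_len], child.parent)
    new_node.children[child.key[split_len]] = child  # tail hangs off new node
    child.parent.children[child.key[0]] = new_node   # parent now sees new node
    child.parent = new_node
    child.key = child.key[split_len:]
    child.value = child.value[split_len:]
    return new_node


root = Node([], [])
leaf = Node(list("capital"), list(range(7)), parent=root)
root.children["c"] = leaf

mid = split_node(leaf, 3)  # "capital" -> "cap" + "ital"
assert root.children["c"] is mid and mid.key == list("cap")
assert mid.children["i"] is leaf and leaf.key == list("ital")
assert mid.value == [0, 1, 2] and leaf.value == [3, 4, 5, 6]
```

After the split, a request matching only "cap" can terminate at `mid` while "capital" remains fully reachable through the chain.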
Memory Management
Reference Counting
Each node maintains a reference count:
```python
def inc_lock_ref(self, node: TreeNode):
    """Increment reference count along path to root.

    Nodes with lock_ref > 0 are protected from eviction.
    """
    delta = 0
    while node != self.root_node:
        if node.lock_ref == 0:
            self.evictable_size_ -= len(node.key)
            self.protected_size_ += len(node.key)
            delta -= len(node.key)
        node.lock_ref += 1
        self._update_leaf_status(node)
        node = node.parent
    return delta
```
Reference: python/sglang/srt/mem_cache/radix_cache.py:607-620
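The bookkeeping can be illustrated with a self-contained toy (not the SGLang classes): locking a leaf walks to the root, moving each newly locked node's tokens from the evictable pool to the protected pool, and the symmetric release, sketched here as `dec_lock_ref` under the same accounting assumptions, reverses it.

```python
class Node:
    def __init__(self, key, parent=None):
        self.key = key
        self.parent = parent
        self.lock_ref = 0


class Cache:
    def __init__(self, root):
        self.root = root
        self.evictable = 0
        self.protected = 0

    def inc_lock_ref(self, node):
        while node is not self.root:
            if node.lock_ref == 0:  # first lock protects the node
                self.evictable -= len(node.key)
                self.protected += len(node.key)
            node.lock_ref += 1
            node = node.parent

    def dec_lock_ref(self, node):
        while node is not self.root:
            node.lock_ref -= 1
            if node.lock_ref == 0:  # last unlock makes it evictable again
                self.protected -= len(node.key)
                self.evictable += len(node.key)
            node = node.parent


root = Node([])
a = Node([1, 2, 3], root)  # 3 tokens
b = Node([4, 5], a)        # 2 tokens
cache = Cache(root)
cache.evictable = 5        # all cached tokens start out evictable

cache.inc_lock_ref(b)      # protects b and its ancestor a
assert (cache.evictable, cache.protected) == (0, 5)
cache.dec_lock_ref(b)
assert (cache.evictable, cache.protected) == (5, 0)
```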
Eviction Policies
When memory is full, nodes are evicted based on configurable policies:
- LRU (Least Recently Used): Evicts the least recently accessed nodes first
- LFU (Least Frequently Used): Evicts the least frequently accessed nodes first
- FIFO (First In, First Out): Evicts the oldest inserted nodes first
- MRU (Most Recently Used): Evicts the most recently accessed nodes first
- FILO (First In, Last Out): Evicts the most recently inserted nodes first (stack order)
- Priority: Evicts based on request priority
```python
def evict(self, params: EvictParams) -> EvictResult:
    """Evict tokens from cache to free memory."""
    num_tokens = params.num_tokens
    leaves = list(self.evictable_leaves)

    # Build heap based on eviction policy
    eviction_heap = [
        (self.eviction_strategy.get_priority(node), node)
        for node in leaves
    ]
    heapq.heapify(eviction_heap)

    num_evicted = 0
    while num_evicted < num_tokens and len(eviction_heap):
        _priority, x = heapq.heappop(eviction_heap)
        # Free KV cache
        self.token_to_kv_pool_allocator.free(x.value)
        num_evicted += len(x.value)
        self._delete_leaf(x)
    return EvictResult(num_tokens_evicted=num_evicted)
```
Reference: python/sglang/srt/mem_cache/radix_cache.py:578-605
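The policy-specific behavior lives entirely in the strategy's `get_priority()` value, which orders the min-heap so that `heappop` yields the next eviction victim. A minimal sketch with illustrative field names (`last_access_time` and `hit_count` are assumptions for this example, not necessarily SGLang's attribute names):

```python
import heapq


class LRUStrategy:
    def get_priority(self, node):
        return node.last_access_time  # oldest access pops first


class LFUStrategy:
    def get_priority(self, node):
        return node.hit_count  # least-used pops first


class Leaf:
    def __init__(self, name, last_access_time, hit_count):
        self.name = name
        self.last_access_time = last_access_time
        self.hit_count = hit_count


def eviction_order(strategy, leaves):
    # Index breaks priority ties so nodes never need to be comparable.
    heap = [(strategy.get_priority(n), i, n) for i, n in enumerate(leaves)]
    heapq.heapify(heap)
    return [heapq.heappop(heap)[2].name for _ in range(len(leaves))]


leaves = [Leaf("a", 100.0, 7), Leaf("b", 50.0, 2), Leaf("c", 75.0, 9)]
assert eviction_order(LRUStrategy(), leaves) == ["b", "c", "a"]
assert eviction_order(LFUStrategy(), leaves) == ["b", "a", "c"]
```

Swapping policies is just swapping the strategy object; the eviction loop itself never changes.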
Attention Computation
The RadixAttention layer (python/sglang/srt/layers/radix_attention.py) computes attention over the tree structure:
```python
class RadixAttention(nn.Module):
    def forward(
        self,
        q,  # Query tensor
        k,  # Key tensor (for new tokens)
        v,  # Value tensor (for new tokens)
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        # Reshape tensors
        if k is not None:
            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
        # Dispatch to attention backend
        return forward_batch.attn_backend.forward(
            q, k, v, self, forward_batch, save_kv_cache, **kwargs
        )
```
Reference: python/sglang/srt/layers/radix_attention.py:99-135
The attention backend (FlashInfer, FlashAttention, etc.) handles the actual computation, accessing cached KV states through the indices provided by RadixCache.
Memory Savings
With shared prefixes, memory usage scales with unique content rather than total tokens:
- Without prefix caching: 3 requests × 1000 tokens = 3000 tokens in memory
- With prefix caching: 800 shared + 3 × 200 unique = 1400 tokens (53% savings)
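The arithmetic behind these numbers:

```python
# Three 1000-token requests that share an 800-token prefix.
shared, unique, n_requests = 800, 200, 3

without_cache = n_requests * (shared + unique)  # every request stores everything
with_cache = shared + n_requests * unique       # prefix stored once

savings = 1 - with_cache / without_cache
assert without_cache == 3000
assert with_cache == 1400
assert round(savings * 100) == 53  # ~53% of KV cache memory saved
```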
Computation Savings
Cached KV states don’t need to be recomputed:
- Prefill time: Only compute attention for new (uncached) tokens
- TTFT (Time to First Token): Dramatically reduced for requests with cached prefixes
RadixAttention is most effective for workloads with:
- Long system prompts
- Few-shot examples
- Document-based QA
- Multi-turn conversations
Integration with Scheduler
The scheduler automatically uses RadixCache during request processing:
```python
class Scheduler:
    def get_next_batch_to_run(self):
        # Match prefixes for waiting requests
        for req in self.waiting_queue:
            match_result = self.tree_cache.match_prefix(
                MatchPrefixParams(key=RadixKey(
                    token_ids=req.origin_input_ids,
                    extra_key=req.extra_key,
                ))
            )
            req.prefix_indices = match_result.device_indices
            req.last_node = match_result.last_device_node
        # Schedule based on prefix matches
        self.policy.calc_priority(self.waiting_queue)
```
Reference: python/sglang/srt/managers/schedule_policy.py:182-240
Page Alignment
For efficient memory management, prefixes can be aligned to page boundaries:
```python
def page_align_keys(key: list, page_size: int) -> list:
    """Align keys to page_size boundaries.

    Example: If page_size=16 and key has 35 tokens,
    returns first 32 tokens (2 complete pages)
    """
    if page_size == 1:
        return key
    page_aligned_len = len(key) // page_size * page_size
    return key[:page_aligned_len]
```
Reference: python/sglang/srt/mem_cache/radix_cache.py:110-114
Partial pages at the end of sequences are not cached. This ensures page-granular sharing and efficient memory allocation.
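A quick usage sketch of the alignment rule (the helper is restated here so the example runs on its own):

```python
def page_align_keys(key, page_size):
    # Round down to a whole number of pages; page_size=1 keeps everything.
    if page_size == 1:
        return key
    return key[: len(key) // page_size * page_size]


tokens = list(range(35))
assert len(page_align_keys(tokens, 16)) == 32  # 2 full pages; 3 tokens dropped
assert page_align_keys(tokens, 1) == tokens    # page_size=1: no alignment
assert page_align_keys(tokens[:10], 16) == []  # less than one page: nothing cached
```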
Configuration
Configure RadixAttention behavior:
```shell
# Disable radix cache (use simpler caching)
--disable-radix-cache

# Set eviction policy
--radix-eviction-policy lru  # Options: lru, lfu, fifo, mru, filo, priority

# Enable EAGLE speculative decoding (uses bigram keys)
--speculative-algorithm eagle

# Set page size for alignment
--page-size 16  # Default: 1 (no alignment)
```
Limitations
- Exact Match Required: Token sequences must match exactly (including special tokens)
- Memory Overhead: Tree structure adds small overhead vs flat arrays
- Eviction Cost: Finding eviction candidates requires heap operations
- Locking Granularity: Entire paths are locked, not individual nodes