RadixAttention is SGLang’s core attention mechanism that enables automatic detection and reuse of common token sequence prefixes across different requests. This dramatically improves throughput and reduces memory usage for workloads with shared context.

What is RadixAttention?

RadixAttention combines two key innovations:
  1. Radix Tree Data Structure: Organizes KV cache as a tree where common prefixes are shared
  2. Attention Mechanism: Efficiently computes attention over the shared tree structure
This allows multiple requests to share the same cached key-value states for identical token sequences, significantly reducing both memory usage and computation.

The Radix Tree

The RadixCache in python/sglang/srt/mem_cache/radix_cache.py implements a prefix tree (trie) data structure:
class TreeNode:
    def __init__(self, id=None, priority=0):
        self.children = defaultdict(TreeNode)
        self.parent: TreeNode = None
        self.key: RadixKey = None          # Token sequence
        self.value: torch.Tensor = None    # KV cache indices
        self.lock_ref = 0                  # Reference count
Reference: python/sglang/srt/mem_cache/radix_cache.py:117-146

Example Tree Structure

Consider three requests:
  • Request 1: “What is the capital of France?”
  • Request 2: “What is the capital of Germany?”
  • Request 3: “What is the weather today?”
The radix tree would look like:
                    root
                     |
              [What is the]
                /         \
          [capital of]   [weather today]
           /        \
      [France?]  [Germany?]
The shared prefix “What is the” is stored only once in memory and reused by all three requests.
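The sharing above can be sketched with a toy word-level prefix tree (illustrative only — SGLang's RadixCache stores token IDs plus KV-cache indices and compresses runs of tokens into radix nodes, rather than one node per token):

```python
class ToyNode:
    def __init__(self):
        self.children = {}  # token -> ToyNode

def insert(root, tokens):
    node = root
    for tok in tokens:
        node = node.children.setdefault(tok, ToyNode())

def count_nodes(node):
    return 1 + sum(count_nodes(c) for c in node.children.values())

root = ToyNode()
requests = [
    "What is the capital of France ?".split(),
    "What is the capital of Germany ?".split(),
    "What is the weather today ?".split(),
]
for toks in requests:
    insert(root, toks)

total_tokens = sum(len(t) for t in requests)  # 20 tokens across the 3 requests
stored = count_nodes(root) - 1                # 12 unique tokens stored in the tree
print(total_tokens, stored)                   # 20 12
```

Twenty request tokens collapse to twelve stored entries because "What is the" and "capital of" exist only once in the tree.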

Key Operations

Prefix Matching

When a new request arrives, RadixCache finds the longest cached prefix:
def match_prefix(self, params: MatchPrefixParams) -> MatchResult:
    """Find the longest cached prefix of key in the radix tree.
    
    Returns:
        MatchResult with device_indices (KV cache locations) and
        last_node (terminal node of matched prefix)
    """
    key = params.key
    value, last_node = self._match_prefix_helper(self.root_node, key)
    
    if value:
        value = torch.cat(value)
    else:
        value = torch.empty((0,), dtype=torch.int64, device=self.device)
    
    return MatchResult(
        device_indices=value,
        last_device_node=last_node,
        last_host_node=last_node,
    )
Reference: python/sglang/srt/mem_cache/radix_cache.py:371-441
Prefix matching happens at page granularity when page_size > 1. If page_size=16, prefixes are aligned to 16-token boundaries.
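The core of the lookup can be sketched as a walk down a nested-dict trie (a simplification — the real `_match_prefix_helper` walks TreeNodes and collects KV indices along the way):

```python
def insert(root: dict, key: list):
    """Insert a token sequence into a toy nested-dict trie."""
    node = root
    for tok in key:
        node = node.setdefault(tok, {})

def match_prefix(root: dict, key: list) -> int:
    """Return the length of the longest cached prefix of `key`."""
    node, depth = root, 0
    for tok in key:
        if tok not in node:
            break
        node = node[tok]
        depth += 1
    return depth

root = {}
insert(root, [1, 2, 3, 4])
print(match_prefix(root, [1, 2, 3, 9]))  # 3: first three tokens are reusable
print(match_prefix(root, [7, 8]))        # 0: no cached prefix
```

Only the matched prefix's KV states are reused; the remaining suffix still goes through prefill.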

Prefix Insertion

After a request completes, its tokens and KV cache are inserted into the tree:
def insert(self, params: InsertParams) -> InsertResult:
    """Insert a key-value pair into the radix tree.
    
    Handles:
    - Creating new nodes for novel suffixes
    - Splitting existing nodes when prefix diverges
    - Updating reference counts
    """
    key = params.key
    value = params.value
    priority = params.priority
    
    prefix_len = self._insert_helper(self.root_node, key, value, priority)
    return InsertResult(prefix_len=prefix_len)
Reference: python/sglang/srt/mem_cache/radix_cache.py:443-457

Node Splitting

When a match ends in the middle of a node, the node is split:
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
    """Split a node at split_len position.
    
    Before: parent -> child[0:N]
    After:  parent -> new_node[0:split_len] -> child[split_len:N]
    """
    new_node = TreeNode(priority=child.priority)
    new_node.children = {self.get_child_key_fn(key[split_len:]): child}
    new_node.parent = child.parent
    new_node.lock_ref = child.lock_ref
    new_node.key = child.key[:split_len]
    new_node.value = child.value[:split_len].clone()
    
    child.parent = new_node
    child.key = child.key[split_len:]
    child.value = child.value[split_len:].clone()
    
    new_node.parent.children[self.get_child_key_fn(key)] = new_node
    return new_node
Reference: python/sglang/srt/mem_cache/radix_cache.py:687-706
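The before/after picture in the docstring can be demonstrated on a stripped-down node that holds only a key and children (a hypothetical simplification, not SGLang's TreeNode):

```python
class Node:
    def __init__(self, key):
        self.key = key        # token run held by this node
        self.children = {}    # first token of child key -> Node

def split(parent, child_tok, split_len):
    """Split parent -> child[0:N] into parent -> new[0:split_len] -> child[split_len:N]."""
    child = parent.children[child_tok]
    new_node = Node(child.key[:split_len])
    child.key = child.key[split_len:]
    new_node.children[child.key[0]] = child
    parent.children[child_tok] = new_node
    return new_node

root = Node([])
root.children[10] = Node([10, 11, 12, 13])
mid = split(root, 10, 2)
print(mid.key, root.children[10].key, mid.children[12].key)
# [10, 11] [10, 11] [12, 13]
```

In the real implementation the node's `value` (KV indices), `lock_ref`, and parent pointers are split alongside the key, as shown in `_split_node` above.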

Memory Management

Reference Counting

Each node maintains a reference count:
def inc_lock_ref(self, node: TreeNode):
    """Increment reference count along path to root.
    
    Nodes with lock_ref > 0 are protected from eviction.
    """
    delta = 0
    while node != self.root_node:
        if node.lock_ref == 0:
            self.evictable_size_ -= len(node.key)
            self.protected_size_ += len(node.key)
            delta -= len(node.key)
        node.lock_ref += 1
        self._update_leaf_status(node)
        node = node.parent
    return delta
Reference: python/sglang/srt/mem_cache/radix_cache.py:607-620
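The path-locking effect can be sketched with a simplified node carrying only `parent`, `size`, and `lock_ref` (illustrative fields, not SGLang's TreeNode):

```python
class Node:
    def __init__(self, parent=None, size=0):
        self.parent, self.size, self.lock_ref = parent, size, 0

def inc_lock_ref(node, root):
    """Lock every node on the path to root; return tokens newly protected."""
    protected = 0
    while node is not root:
        if node.lock_ref == 0:
            protected += node.size  # moves from evictable to protected
        node.lock_ref += 1
        node = node.parent
    return protected

root = Node()
a = Node(root, size=16)  # a 16-token node
b = Node(a, size=8)      # its 8-token child

protected = inc_lock_ref(b, root)
print(protected)  # 24: the whole path b -> a is now pinned
```

A second lock on the same path protects nothing new (the nodes are already non-evictable); only the reference counts increase.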

Eviction Policies

When memory is full, nodes are evicted based on configurable policies:
  • LRU (Least Recently Used): evicts the least recently accessed nodes first
  • LFU (Least Frequently Used): evicts the least frequently accessed nodes first
  • FIFO (First In, First Out): evicts the oldest inserted nodes first
  • MRU (Most Recently Used): evicts the most recently accessed nodes first
  • FILO (First In, Last Out): evicts the newest inserted nodes first (stack order)
  • Priority: evicts nodes based on per-request priority
def evict(self, params: EvictParams) -> EvictResult:
    """Evict tokens from cache to free memory."""
    num_tokens = params.num_tokens
    leaves = list(self.evictable_leaves)
    
    # Build heap based on eviction policy
    eviction_heap = [
        (self.eviction_strategy.get_priority(node), node) 
        for node in leaves
    ]
    heapq.heapify(eviction_heap)
    
    num_evicted = 0
    while num_evicted < num_tokens and len(eviction_heap):
        _priority, x = heapq.heappop(eviction_heap)
        if x.lock_ref > 0:
            continue  # Locked nodes are protected from eviction
        
        # Free KV cache and remove the leaf
        self.token_to_kv_pool_allocator.free(x.value)
        num_evicted += len(x.value)
        self._delete_leaf(x)
        
        # If the parent became a leaf, it is now an eviction candidate
        if len(x.parent.children) == 0:
            heapq.heappush(
                eviction_heap,
                (self.eviction_strategy.get_priority(x.parent), x.parent),
            )
    
    return EvictResult(num_tokens_evicted=num_evicted)
Reference: python/sglang/srt/mem_cache/radix_cache.py:578-605
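Each policy can be viewed as a priority function over per-node metadata, where the smallest priority is evicted first — which is exactly what the heap above consumes. A hedged sketch (the field names `last_access`, `hit_count`, and `insert_seq` are illustrative, not SGLang's actual attributes):

```python
import heapq

# Smaller priority value = evicted first.
policies = {
    "lru":  lambda n: n["last_access"],   # oldest access evicted first
    "lfu":  lambda n: n["hit_count"],     # fewest hits evicted first
    "fifo": lambda n: n["insert_seq"],    # oldest insertion evicted first
    "mru":  lambda n: -n["last_access"],  # newest access evicted first
}

nodes = [
    {"id": "a", "last_access": 5, "hit_count": 9, "insert_seq": 0},
    {"id": "b", "last_access": 1, "hit_count": 3, "insert_seq": 1},
    {"id": "c", "last_access": 7, "hit_count": 1, "insert_seq": 2},
]

def eviction_order(policy):
    key = policies[policy]
    # Index i breaks priority ties so dicts are never compared directly.
    heap = [(key(n), i, n["id"]) for i, n in enumerate(nodes)]
    heapq.heapify(heap)
    return [heapq.heappop(heap)[2] for _ in range(len(heap))]

print(eviction_order("lru"))  # ['b', 'a', 'c']
print(eviction_order("lfu"))  # ['c', 'b', 'a']
```

Swapping the policy therefore changes only the priority function; the heap-driven eviction loop stays the same.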

Attention Computation

The RadixAttention layer (python/sglang/srt/layers/radix_attention.py) computes attention over the tree structure:
class RadixAttention(nn.Module):
    def forward(
        self,
        q,          # Query tensor
        k,          # Key tensor (for new tokens)
        v,          # Value tensor (for new tokens)
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        # Reshape tensors
        if k is not None:
            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
        
        # Dispatch to attention backend
        return forward_batch.attn_backend.forward(
            q, k, v, self, forward_batch, save_kv_cache, **kwargs
        )
Reference: python/sglang/srt/layers/radix_attention.py:99-135
The attention backend (FlashInfer, FlashAttention, etc.) handles the actual computation, accessing cached KV states through the indices provided by RadixCache.

Performance Benefits

Memory Savings

With shared prefixes, memory usage scales with unique content rather than total tokens:
  • Without prefix caching: 3 requests × 1000 tokens = 3000 tokens in memory
  • With prefix caching: 800 shared + 3 × 200 unique = 1400 tokens (53% savings)
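The arithmetic above, spelled out (the 800-shared / 200-unique split is the example's assumption, not a measurement):

```python
shared, unique, n = 800, 200, 3
without_cache = n * (shared + unique)  # 3000 tokens resident in KV cache
with_cache = shared + n * unique       # 1400 tokens resident
savings = 1 - with_cache / without_cache
print(without_cache, with_cache, round(savings * 100))  # 3000 1400 53
```

The longer the shared prefix relative to the unique suffixes, the closer memory usage approaches the cost of a single request.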

Computation Savings

Cached KV states don’t need to be recomputed:
  • Prefill time: Only compute attention for new (uncached) tokens
  • TTFT (Time to First Token): Dramatically reduced for requests with cached prefixes
RadixAttention is most effective for workloads with:
  • Long system prompts
  • Few-shot examples
  • Document-based QA
  • Multi-turn conversations

Integration with Scheduler

The scheduler automatically uses RadixCache during request processing:
class Scheduler:
    def get_next_batch_to_run(self):
        # Match prefixes for waiting requests
        for req in self.waiting_queue:
            match_result = self.tree_cache.match_prefix(
                MatchPrefixParams(key=RadixKey(
                    token_ids=req.origin_input_ids,
                    extra_key=req.extra_key
                ))
            )
            req.prefix_indices = match_result.device_indices
            req.last_node = match_result.last_device_node
        
        # Schedule based on prefix matches
        self.policy.calc_priority(self.waiting_queue)
Reference: python/sglang/srt/managers/schedule_policy.py:182-240

Page Alignment

For efficient memory management, prefixes can be aligned to page boundaries:
def page_align_keys(key: list, page_size: int) -> list:
    """Align keys to page_size boundaries.
    
    Example: If page_size=16 and key has 35 tokens,
             returns first 32 tokens (2 complete pages)
    """
    if page_size == 1:
        return key
    page_aligned_len = len(key) // page_size * page_size
    return key[:page_aligned_len]
Reference: python/sglang/srt/mem_cache/radix_cache.py:110-114
Partial pages at the end of sequences are not cached. This ensures page-granular sharing and efficient memory allocation.
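A quick usage sketch of the alignment rule; `page_align_keys` is restated here so the snippet runs standalone:

```python
def page_align_keys(key: list, page_size: int) -> list:
    """Truncate key down to a whole number of pages."""
    if page_size == 1:
        return key
    page_aligned_len = len(key) // page_size * page_size
    return key[:page_aligned_len]

key = list(range(35))                 # 35 cached tokens
print(len(page_align_keys(key, 16)))  # 32: only two full pages are kept
print(len(page_align_keys(key, 1)))   # 35: page_size=1 keeps everything
```

With `page_size=16`, a sequence shorter than one page (e.g. 15 tokens) aligns to zero tokens and contributes nothing to the cache.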

Configuration

Configure RadixAttention behavior:
# Disable radix cache (use simpler caching)
--disable-radix-cache

# Set eviction policy
--radix-eviction-policy lru  # Options: lru, lfu, fifo, mru, filo, priority

# Enable EAGLE speculative decoding (uses bigram keys)
--speculative-algorithm eagle

# Set page size for alignment
--page-size 16  # Default: 1 (no alignment)

Limitations

  1. Exact Match Required: Token sequences must match exactly (including special tokens)
  2. Memory Overhead: Tree structure adds small overhead vs flat arrays
  3. Eviction Cost: Finding eviction candidates requires heap operations
  4. Locking Granularity: Entire paths are locked, not individual nodes