Radix Cache is a key optimization in Mini-SGLang that enables the reuse of Key-Value (KV) cache for shared prefixes across requests. This significantly reduces redundant computation and improves serving efficiency, especially for workloads with common prefixes.

Overview

Adopting the original design from SGLang, Mini-SGLang implements a Radix Cache to manage the KV cache intelligently. Instead of recomputing the KV cache for every request, the system identifies and reuses cached computations from previous requests with matching prefixes.

(Figure: illustration of Radix Attention, from the LMSYS blog.)

How It Works

Radix Tree Structure

The Radix Cache is built on a radix tree data structure where:
  • Nodes represent token sequences and their corresponding KV cache indices
  • Edges connect sequences that share common prefixes
  • The root is always protected and represents the empty sequence
Each node in the tree stores:
  • Key: The token sequence (input_ids)
  • Value: The physical page indices in the KV cache
  • Reference count: Number of active requests using this prefix
  • Timestamp: Last access time for eviction decisions
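The fields above can be sketched as a minimal node class. This is an illustrative stand-in, not the actual Mini-SGLang `RadixTreeNode` definition; field names and types are assumptions:

```python
import time
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple

@dataclass
class Node:
    key: Tuple[int, ...] = ()       # token sequence owned by this node
    value: Tuple[int, ...] = ()     # physical page indices in the KV cache
    ref_count: int = 0              # active requests pinning this prefix
    timestamp: float = field(default_factory=time.monotonic)  # for LRU eviction
    parent: Optional["Node"] = None
    children: Dict[int, "Node"] = field(default_factory=dict)  # first token -> child

    def is_root(self) -> bool:
        return self.parent is None

    def is_leaf(self) -> bool:
        return not self.children

root = Node()  # the protected empty-sequence root
child = Node(key=(1, 2), value=(10, 11), parent=root)
root.children[1] = child
```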

Prefix Matching

When a new request arrives, the system walks the radix tree to find the longest matching prefix:
def match_prefix(self, input_ids: torch.Tensor) -> MatchResult:
    node, prefix_len = self._tree_walk(input_ids)
    return MatchResult(RadixCacheHandle(prefix_len, node))
The tree walk algorithm:
  1. Starts at the root node
  2. Checks children for matching prefixes (using page-aligned keys)
  3. Walks down the tree while matches are found
  4. Splits nodes if partial matches occur
  5. Returns the matched node and prefix length
Prefix matching is page-aligned based on the system’s page size. This means only complete pages are matched and reused, ensuring efficient memory management.
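A simplified, pure-Python version of the page-aligned walk over plain token lists can make the matching rule concrete. The dict-of-dicts tree below is a toy stand-in for the real radix tree:

```python
from typing import Dict, List

def match_prefix(tree: Dict, tokens: List[int], page_size: int) -> int:
    """Return the page-aligned longest matching prefix length.

    `tree` maps a page (a tuple of page_size tokens) to a subtree dict;
    this is a simplified stand-in for the real radix tree walk.
    """
    matched = 0
    node = tree
    while matched + page_size <= len(tokens):
        page = tuple(tokens[matched:matched + page_size])
        if page not in node:
            break  # partial or no match: stop at the last full page
        node = node[page]
        matched += page_size
    return matched

# Cached sequence [1, 2, 3, 4], stored as two pages of size 2:
tree: Dict = {(1, 2): {(3, 4): {}}}
full_match = match_prefix(tree, [1, 2, 3, 4, 5], page_size=2)   # both pages reused
first_page_only = match_prefix(tree, [1, 2, 9, 9], page_size=2)  # only first page reused
```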

Cache Insertion

After processing a request, the system inserts the new sequence into the cache:
def insert_prefix(self, input_ids: torch.Tensor, indices: torch.Tensor) -> InsertResult:
    insert_len = align_down(len(input_ids), self.page_size)
    input_ids, indices = input_ids[:insert_len], indices[:insert_len]
    node, prefix_len = self._tree_walk(input_ids)
    if prefix_len != insert_len:
        # Create new node for the unmatched suffix
        new_node = RadixTreeNode(self.key_fn)
        new_node.set_key_value(input_ids[prefix_len:], indices[prefix_len:].clone())
        new_node.set_parent(node)
        self.evictable_size += new_node.length
        node = new_node
    return InsertResult(prefix_len, RadixCacheHandle(insert_len, node))
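The `align_down` helper used above simply rounds a length down to a page boundary, so only complete pages are ever inserted. A sketch:

```python
def align_down(n: int, page_size: int) -> int:
    # Round n down to the nearest multiple of page_size.
    return n - n % page_size

# With page_size = 16, a 70-token sequence caches only 4 full pages (64 tokens);
# the trailing 6 tokens are not inserted.
cached_len = align_down(70, 16)
```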

Reference Counting

The system uses reference counting to protect active cache entries from eviction:
  • Locked (ref_count > 0): The prefix is actively used by one or more requests and cannot be evicted
  • Unlocked (ref_count = 0): The prefix is cached but not in use, available for eviction if memory is needed
When a request is scheduled:
def lock_handle(self, handle: BaseCacheHandle, unlock: bool = False) -> None:
    node = handle.node
    while not node.is_root():
        if unlock:
            node.ref_count -= 1
            if node.ref_count == 0:
                self.evictable_size += node.length
                self.protected_size -= node.length
        else:
            if node.ref_count == 0:
                self.evictable_size -= node.length
                self.protected_size += node.length
            node.ref_count += 1
        node = node.parent
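The effect of this walk can be seen on a two-node chain. The toy classes below mirror the bookkeeping logic shown above but are not the real Mini-SGLang types:

```python
class Node:
    def __init__(self, length: int = 0, parent: "Node | None" = None):
        self.length, self.parent, self.ref_count = length, parent, 0

    def is_root(self) -> bool:
        return self.parent is None

class Cache:
    def __init__(self, evictable_size: int = 0):
        self.evictable_size = evictable_size
        self.protected_size = 0

    def lock_handle(self, node: Node, unlock: bool = False) -> None:
        # Same accounting as the snippet above: walk to the root,
        # moving each node between the evictable and protected pools.
        while not node.is_root():
            if unlock:
                node.ref_count -= 1
                if node.ref_count == 0:
                    self.evictable_size += node.length
                    self.protected_size -= node.length
            else:
                if node.ref_count == 0:
                    self.evictable_size -= node.length
                    self.protected_size += node.length
                node.ref_count += 1
            node = node.parent

root = Node()
a = Node(length=4, parent=root)
b = Node(length=2, parent=a)
cache = Cache(evictable_size=6)  # both nodes start cached but unlocked
cache.lock_handle(b)             # pins b and every ancestor up to the root
```

After the lock, all 6 tokens are protected; unlocking with `unlock=True` returns them to the evictable pool.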

Cache Eviction

When memory is needed, the system evicts the least recently used (LRU) leaf nodes:
  1. Collect evictable leaves: Find all leaf nodes with ref_count == 0
  2. Heapify by timestamp: Sort by least recently accessed
  3. Evict until target met: Remove nodes and update parent leaf status
  4. Cascade eviction: If parent becomes a leaf with ref_count == 0, it becomes evictable too
def evict(self, size: int) -> torch.Tensor:
    leave_nodes = self._collect_leave_nodes_for_evict()
    heapq.heapify(leave_nodes)
    evicted_indices: List[torch.Tensor] = []
    evicted_size = 0
    
    while evicted_size < size:
        node = heapq.heappop(leave_nodes)
        evicted_size += node.length
        evicted_indices.append(node.value)
        self.evictable_size -= node.length
        parent = node.parent
        del parent.children[self.key_fn(node._key)]
        # Parent may become evictable after child removal
        if parent.is_leaf() and parent.ref_count == 0:
            heapq.heappush(leave_nodes, parent)
    
    return torch.cat(evicted_indices)
Eviction order is LRU, driven by each node's last-access timestamp. A node is evicted only when it is a leaf and has no active references.
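The LRU ordering comes from Python's `heapq`, which pops the smallest element first. In the real code the heap compares `RadixTreeNode` objects by timestamp; the minimal illustration below uses `(timestamp, length)` tuples instead:

```python
import heapq

# Evictable leaves as (last_access_time, length); heapq pops the oldest first.
leaves = [(30.0, 4), (10.0, 8), (20.0, 2)]
heapq.heapify(leaves)

target, evicted_size = 10, 0
eviction_order = []
while evicted_size < target:
    ts, length = heapq.heappop(leaves)
    eviction_order.append(ts)
    evicted_size += length
```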

Benefits

Reduced Redundant Computation

Consider a chatbot scenario where multiple users ask similar questions:
User 1: "What is the capital of France?"
User 2: "What is the capital of Germany?"
User 3: "What is the capital of Italy?"
All three requests share the prefix “What is the capital of”. With Radix Cache:
  • First request: Computes KV cache for entire sequence
  • Second request: Reuses KV cache for “What is the capital of”, only computes “Germany?”
  • Third request: Reuses KV cache for “What is the capital of”, only computes “Italy?”
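The saving is easy to quantify with hypothetical token counts, say a 7-token shared prefix and 2 unique tokens per prompt:

```python
# Hypothetical counts: shared prefix of 7 tokens, 2 unique tokens per prompt.
shared, unique, prompts = 7, 2, 3

without_cache = prompts * (shared + unique)              # every prompt prefilled in full
with_cache = (shared + unique) + (prompts - 1) * unique  # prefix computed only once
```

Here the cache cuts prefill work from 27 tokens to 13; the gap grows with longer shared prefixes and more requests.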

Memory Efficiency

Shared prefixes are stored only once in the cache, reducing memory footprint compared to per-request caching.

Lower Latency

Reusing cached KV values eliminates redundant forward passes through transformer layers for matching prefixes, reducing time-to-first-token (TTFT).

Configuration

Radix Cache is enabled by default in Mini-SGLang. To switch to a naive cache management strategy:
python -m minisgl --model "Qwen/Qwen3-0.6B" --cache naive

Page Size

The page size affects cache granularity. Larger pages reduce tree metadata and lookup overhead, but because only complete pages are matched, they can reduce reuse and waste memory when shared prefixes end mid-page:
python -m minisgl --model "Qwen/Qwen3-0.6B" --page-size 16
For workloads with many shared prefixes (e.g., multi-turn conversations, few-shot prompting, batch inference with common system prompts), Radix Cache can provide significant performance improvements.

Implementation Details

Key Function

The key function determines how sequences are indexed in the tree:
def _get_key_fn(page_size: int) -> KEY_FN:
    if page_size == 1:
        return lambda x: x[0].item()
    return lambda x: tuple(x[:page_size].tolist())
For page_size=1, the key is a single token. For larger page sizes, the key is a tuple of the first page_size tokens.
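Applied to plain Python lists (the real version operates on `torch.Tensor` inputs, hence `.item()` and `.tolist()`), the key function behaves like this:

```python
def get_key_fn(page_size: int):
    # List-based stand-in for the tensor version shown above.
    if page_size == 1:
        return lambda x: x[0]             # single-token key
    return lambda x: tuple(x[:page_size])  # one full page as a hashable tuple

single = get_key_fn(1)([7, 8, 9])
paged = get_key_fn(4)([7, 8, 9, 10, 11])
```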

Node Splitting

When a partial match occurs, the existing node is split to create a new branch:
def split_at(self, pos: int) -> RadixTreeNode:
    parent = self.parent
    
    # Create new node for the matched prefix
    new_node = RadixTreeNode(self.key_fn, self.timestamp)
    new_node.set_key_value(self._key[:pos], self._value[:pos])
    new_node.set_parent(parent)
    new_node.ref_count = self.ref_count
    
    # Update current node for the suffix
    self.set_key_value(self._key[pos:], self._value[pos:])
    self.set_parent(new_node)
    
    return new_node
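The splice can be verified on a toy node: after splitting a 4-token node at position 2, the new node holds the matched prefix and the original node keeps the suffix, with the parent pointer rewired. Again, these are illustrative classes, not the real ones:

```python
class Node:
    def __init__(self, key=(), value=(), parent=None, ref_count=0):
        self.key, self.value = key, value
        self.parent, self.ref_count = parent, ref_count

    def split_at(self, pos: int) -> "Node":
        # New node takes the matched prefix and is spliced in above self.
        new = Node(self.key[:pos], self.value[:pos], self.parent, self.ref_count)
        self.key, self.value = self.key[pos:], self.value[pos:]
        self.parent = new
        return new

n = Node(key=(1, 2, 3, 4), value=(10, 11, 12, 13))
mid = n.split_at(2)  # matched the first 2 tokens only
```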

Fast Key Comparison

Mini-SGLang uses a custom CUDA kernel for fast token sequence comparison:
from minisgl.kernel import fast_compare_key

def get_match_len(self, input_ids: torch.Tensor) -> int:
    # Compare key and input_ids, find the first diff
    return fast_compare_key(self._key, input_ids)
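The kernel's behavior is equivalent to finding the index of the first mismatch between two token sequences. A pure-Python reference version (illustrative only, not the actual CUDA kernel) over plain lists:

```python
from typing import Sequence

def get_match_len(key: Sequence[int], input_ids: Sequence[int]) -> int:
    # Index of the first differing token; the full overlap if none differ.
    n = min(len(key), len(input_ids))
    for i in range(n):
        if key[i] != input_ids[i]:
            return i
    return n

partial = get_match_len([1, 2, 3], [1, 2, 9, 9])  # mismatch at index 2
full = get_match_len([1, 2], [1, 2, 3])           # key fully contained
```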

Architecture

Understand how cache management fits into the system

Chunked Prefill

Learn about memory-efficient prefill strategies
