Memory Management

This guide covers SGLang’s memory management system, including KV cache allocation, radix caching, and memory optimizations.

Overview

SGLang’s memory system manages:
  • Model weights (static, loaded once)
  • KV cache (dynamic, per request)
  • Activation memory (temporary, per batch)
  • Workspace buffers (scratch space for kernels)
Location: python/sglang/srt/mem_cache/

Memory Layout

GPU Memory Breakdown

Total GPU Memory (e.g., 80GB on A100)
├── Model Weights (static) ~ 15GB
│   └── Fixed after loading
├── KV Cache Pool (dynamic) ~ 60GB
│   ├── Request 1: 2GB
│   ├── Request 2: 1.5GB
│   ├── Request 3: 3GB
│   └── Free space: 53.5GB
├── Activation Memory (temp) ~ 4GB
│   └── Reused across batches
└── Workspace Buffers ~ 1GB
    └── Scratch space for kernels

Memory Allocation at Startup

def init_memory_pool(self, config):
    """Initialize memory pools."""
    # Calculate available memory
    total_mem = torch.cuda.get_device_properties(0).total_memory
    model_mem = self.get_model_memory()
    
    # Reserve memory for KV cache
    kv_cache_mem = total_mem * config.mem_fraction_static - model_mem
    
    # Create memory pool
    self.token_to_kv_pool = TokenToKVPool(
        size=kv_cache_mem,
        token_size=self.get_kv_cache_token_size(),
    )
    
    logger.info(f"KV cache pool size: {kv_cache_mem / 1e9:.2f} GB")
    logger.info(f"Max tokens: {kv_cache_mem // self.get_kv_cache_token_size()}")

Token-to-KV Pool

Architecture

The TokenToKVPool manages KV cache allocation at token granularity.
class TokenToKVPool:
    """Manages KV cache memory at token level."""
    
    def __init__(self, size, token_size, num_layers, num_kv_heads, head_dim):
        self.size = size
        self.token_size = token_size  # Bytes per token's KV cache
        self.max_tokens = size // token_size
        
        # Pre-allocate GPU memory: [token, layer, K/V, kv_head, dim]
        self.kv_data = torch.empty(
            (self.max_tokens, num_layers, 2, num_kv_heads, head_dim),
            dtype=torch.float16,
            device="cuda"
        )
        
        # Free list of token slots
        self.free_slots = set(range(self.max_tokens))
    
    def allocate(self, num_tokens):
        """Allocate KV cache for num_tokens."""
        if len(self.free_slots) < num_tokens:
            return None  # OOM
        
        # Allocate from free list
        allocated = [self.free_slots.pop() for _ in range(num_tokens)]
        return allocated
    
    def free(self, slots):
        """Free KV cache slots."""
        self.free_slots.update(slots)
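The free-list bookkeeping above can be exercised without a GPU. A minimal sketch with the same allocate/free logic (the tensor allocation is omitted, so this is illustrative, not the actual SGLang class):

```python
class SlotAllocator:
    """Free-list slot allocator, mirroring TokenToKVPool's bookkeeping."""

    def __init__(self, max_tokens):
        self.max_tokens = max_tokens
        self.free_slots = set(range(max_tokens))

    def allocate(self, num_tokens):
        """Take num_tokens slots from the free list, or None on OOM."""
        if len(self.free_slots) < num_tokens:
            return None
        return [self.free_slots.pop() for _ in range(num_tokens)]

    def free(self, slots):
        """Return slots to the free list."""
        self.free_slots.update(slots)

pool = SlotAllocator(max_tokens=8)
a = pool.allocate(5)        # 5 slots taken, 3 remain
b = pool.allocate(4)        # None: only 3 slots left
pool.free(a)                # all 8 slots free again
```

Note that allocation can fail partway through a batch; the scheduler handles this by evicting cached prefixes or pausing requests rather than crashing.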

Per-Token KV Cache Size

For a model with:
  • num_layers = 32
  • num_kv_heads = 8 (GQA)
  • head_dim = 128
  • dtype = fp16 (2 bytes)
kv_cache_per_token = (
    num_layers *      # 32
    2 *               # K and V
    num_kv_heads *    # 8
    head_dim *        # 128
    2                 # bytes (fp16)
)
= 32 * 2 * 8 * 128 * 2 = 131,072 bytes = 128 KB
For 10,000 tokens: 10,000 * 131,072 bytes ≈ 1.31 GB
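The arithmetic above generalizes to any model configuration; a small helper (the function name is illustrative, not SGLang API):

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, dtype_bytes=2):
    """Bytes of KV cache one token occupies:
    layers x (K and V) x kv_heads x head_dim x bytes-per-element."""
    return num_layers * 2 * num_kv_heads * head_dim * dtype_bytes

per_token = kv_bytes_per_token(32, 8, 128)  # the 32-layer GQA model above
total = 10_000 * per_token                  # cache footprint for 10,000 tokens
```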

Radix Cache

Radix Tree for Prefix Sharing

Radix cache uses a tree structure to share KV cache across requests with common prefixes.
class RadixCache:
    """Radix tree for prefix caching."""
    
    class Node:
        def __init__(self):
            self.children = {}  # token_id -> Node
            self.kv_indices = []  # Slots in token_to_kv_pool
            self.ref_count = 0  # How many requests reference this
            self.last_access = time.time()
    
    def __init__(self):
        self.root = self.Node()
        self.total_nodes = 0
    
    def match(self, tokens):
        """Find longest matching prefix."""
        node = self.root
        matched_indices = []
        
        for token in tokens:
            if token in node.children:
                node = node.children[token]
                matched_indices.extend(node.kv_indices)
                node.last_access = time.time()
            else:
                break
        
        return matched_indices
    
    def insert(self, tokens, kv_indices):
        """Insert new prefix into tree."""
        node = self.root
        
        for i, token in enumerate(tokens):
            if token not in node.children:
                node.children[token] = self.Node()
                self.total_nodes += 1
            
            node = node.children[token]
            node.kv_indices = [kv_indices[i]]  # Only this node's slot; match() accumulates along the path
            node.ref_count += 1

Example: Prefix Sharing

# Request 1: "Translate to French: Hello"
tokens1 = [1054, 284, 2823, 25, 15496]  # "Translate to French: Hello"
kv1 = pool.allocate(len(tokens1))  # [0, 1, 2, 3, 4]
radix_cache.insert(tokens1, kv1)

# Request 2: "Translate to French: Goodbye" (shares prefix)
tokens2 = [1054, 284, 2823, 25, 7197, 29474]  # "Translate to French: Goodbye"
matched = radix_cache.match(tokens2)  # Returns [0, 1, 2, 3] (shared prefix)
remaining = len(tokens2) - len(matched)  # 2 tokens
kv2_new = pool.allocate(remaining)  # [5, 6]
kv2 = matched + kv2_new  # [0, 1, 2, 3, 5, 6]
Memory Saved: 4 tokens * 128 KB = 512 KB per request
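Putting the pieces together, a self-contained sketch of the sharing logic (pure Python; a dict-based trie stands in for the RadixCache and a counter stands in for the pool, so all names here are illustrative):

```python
class PrefixCache:
    """Token-level trie mapping shared prefixes to their KV slot indices."""

    def __init__(self):
        self.root = {}  # token_id -> (child dict, kv_slot)

    def match(self, tokens):
        """Return the KV slots of the longest cached prefix."""
        node, slots = self.root, []
        for t in tokens:
            if t not in node:
                break
            child, slot = node[t]
            slots.append(slot)
            node = child
        return slots

    def insert(self, tokens, kv_slots):
        """Record one KV slot per token along the path."""
        node = self.root
        for t, s in zip(tokens, kv_slots):
            if t not in node:
                node[t] = ({}, s)
            node = node[t][0]

cache = PrefixCache()
next_slot = 0

def allocate(n):
    """Hand out n fresh slot ids (stand-in for TokenToKVPool.allocate)."""
    global next_slot
    slots = list(range(next_slot, next_slot + n))
    next_slot += n
    return slots

tokens1 = [1054, 284, 2823, 25, 15496]        # "Translate to French: Hello"
kv1 = allocate(len(tokens1))                  # [0, 1, 2, 3, 4]
cache.insert(tokens1, kv1)

tokens2 = [1054, 284, 2823, 25, 7197, 29474]  # "Translate to French: Goodbye"
matched = cache.match(tokens2)                # [0, 1, 2, 3]: 4 tokens reused
kv2 = matched + allocate(len(tokens2) - len(matched))  # [0, 1, 2, 3, 5, 6]
```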

Memory Allocation Strategies

Lazy Allocation

Allocate KV cache incrementally as tokens are generated:
class Request:
    def allocate_kv_cache(self, num_new_tokens):
        """Allocate KV cache for new tokens."""
        # Try to match prefix first
        matched = self.radix_cache.match(self.input_ids)
        
        # Copy so extending below cannot mutate the cached list
        self.kv_indices = list(matched)
        num_new_tokens -= len(matched)
        
        # Allocate remaining
        if num_new_tokens > 0:
            new_indices = self.pool.allocate(num_new_tokens)
            if new_indices is None:
                raise OutOfMemoryError("KV cache pool exhausted")
            
            self.kv_indices.extend(new_indices)

Eager Eviction

Free cache immediately when request finishes:
def finish_request(self, req):
    """Clean up request resources."""
    # Decrement reference counts in radix tree
    self.radix_cache.decrement_refs(req.input_ids)
    
    # Free KV cache slots
    self.pool.free(req.kv_indices)
    
    # Remove from running batch
    self.running_batch.remove(req)

Cache Eviction Policy

When memory is full, evict least recently used (LRU) cached prefixes:
def evict_cache(self, required_slots):
    """Evict cached prefixes to free memory."""
    # Find eviction candidates (ref_count == 0)
    candidates = []
    self._collect_candidates(self.radix_cache.root, candidates)
    
    # Sort by last access time (LRU)
    candidates.sort(key=lambda node: node.last_access)
    
    # Evict until enough memory
    freed = 0
    for node in candidates:
        if freed >= required_slots:
            break
        
        # Free this node's KV cache
        self.pool.free(node.kv_indices)
        freed += len(node.kv_indices)
        
        # Remove from tree
        self._remove_node(node)
    
    return freed >= required_slots
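`_collect_candidates` and `_remove_node` above are hypothetical helpers. One way to sketch the candidate walk (written standalone here with a minimal Node for illustration) is a depth-first pass that collects only unreferenced leaves, since an inner node's slots are still needed by the prefixes of its descendants:

```python
class Node:
    """Minimal stand-in for a radix-tree node."""
    def __init__(self, ref_count=0):
        self.children = {}        # token_id -> Node
        self.kv_indices = []      # Slots in the KV pool
        self.ref_count = ref_count

def collect_candidates(node, candidates, is_root=True):
    """DFS: gather unreferenced leaf nodes as eviction candidates."""
    for child in node.children.values():
        collect_candidates(child, candidates, is_root=False)
    # Only leaves with no live references are safe to evict; removing
    # an inner node would break the prefixes that pass through it.
    if not is_root and node.ref_count == 0 and not node.children:
        candidates.append(node)
```

After one pass evicts some leaves, their parents may become evictable leaves themselves, so a real eviction loop repeats until enough slots are freed.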

KV Cache Formats

Contiguous Format

All tokens’ KV cache stored contiguously:
# Shape: [num_tokens, num_layers, 2, num_kv_heads, head_dim]
kv_cache = torch.zeros(
    (seq_len, num_layers, 2, num_kv_heads, head_dim),
    dtype=torch.float16,
    device="cuda"
)

# Access K for layer i, token j
k = kv_cache[j, i, 0]  # [num_kv_heads, head_dim]

# Access V for layer i, token j
v = kv_cache[j, i, 1]  # [num_kv_heads, head_dim]

Paged Format

KV cache split into fixed-size pages (e.g., PagedAttention):
page_size = 16  # tokens per page
num_pages = (seq_len + page_size - 1) // page_size

# Shape: [num_pages, page_size, num_layers, 2, num_kv_heads, head_dim]
kv_cache = torch.zeros(
    (num_pages, page_size, num_layers, 2, num_kv_heads, head_dim),
    dtype=torch.float16,
    device="cuda"
)

# Access via page table
page_table = [0, 1, 2, ...]  # Maps logical pages to physical pages
token_idx = 25
page_idx = token_idx // page_size  # 1
offset = token_idx % page_size     # 9
physical_page = page_table[page_idx]
k = kv_cache[physical_page, offset, layer_idx, 0]
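The index arithmetic above can be wrapped in a helper (illustrative, not SGLang's API):

```python
def locate_token(token_idx, page_table, page_size=16):
    """Map a logical token index to (physical_page, offset) via the page table."""
    page_idx = token_idx // page_size   # which logical page the token is on
    offset = token_idx % page_size      # position within that page
    return page_table[page_idx], offset

page_table = [7, 3, 9]  # logical page -> physical page
assert locate_token(25, page_table) == (3, 9)  # token 25: logical page 1, offset 9
```

Because only the page table changes when memory moves, sequences can grow page by page without ever needing a contiguous region.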

Memory Optimizations

1. Quantized KV Cache

Store KV cache in lower precision:
class QuantizedKVCache:
    """INT8-quantized KV cache."""
    
    def __init__(self, max_tokens, num_layers, num_kv_heads, head_dim):
        # Store in INT8 instead of FP16
        self.kv_data = torch.zeros(
            (max_tokens, num_layers, 2, num_kv_heads, head_dim),
            dtype=torch.int8,  # 1 byte instead of 2
            device="cuda"
        )
        # Store scale factors for dequantization
        self.scales = torch.zeros(
            (max_tokens, num_layers, 2, num_kv_heads, 1),
            dtype=torch.float16,
            device="cuda"
        )
    
    def store(self, token_slots, layer_idx, kv_fp16):
        """Quantize and store KV cache for the given token slots."""
        # Symmetric per-tensor quantization to INT8
        scale = kv_fp16.abs().max() / 127.0
        kv_int8 = (kv_fp16 / scale).round().clamp(-128, 127).to(torch.int8)
        
        # Store quantized values and scale
        self.kv_data[token_slots, layer_idx] = kv_int8
        self.scales[token_slots, layer_idx] = scale
    
    def load(self, token_slots, layer_idx):
        """Dequantize and load KV cache."""
        kv_int8 = self.kv_data[token_slots, layer_idx]
        scale = self.scales[token_slots, layer_idx]
        
        # Dequantize to FP16
        return kv_int8.to(torch.float16) * scale
Memory Savings: ~50% (INT8 vs FP16, minus a small overhead for the scale factors)
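The quantize/dequantize round trip is lossy. A dependency-free sketch of the same symmetric scheme shows the reconstruction error stays within half a quantization step:

```python
def quantize(values):
    """Symmetric per-tensor INT8 quantization, as in store() above."""
    scale = max(abs(v) for v in values) / 127.0
    q = [max(-128, min(127, round(v / scale))) for v in values]
    return q, scale

def dequantize(q, scale):
    """Map INT8 codes back to floats."""
    return [x * scale for x in q]

vals = [0.5, -1.0, 0.25, 0.9]
q, scale = quantize(vals)
restored = dequantize(q, scale)
# Worst-case error of rounding to the nearest code is scale / 2
assert all(abs(a - b) <= scale / 2 + 1e-9 for a, b in zip(vals, restored))
```

In practice, per-head or per-channel scales (as the `scales` tensor shape above allows) tighten this bound further, since one outlier no longer stretches the scale for the whole tensor.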

2. HiCache (Hierarchical Storage)

Offload cold KV cache to CPU or SSD:
class HiCache:
    """Hierarchical cache with L1 (GPU), L2 (CPU), L3 (SSD)."""
    
    def __init__(self):
        self.l1_cache = {}  # GPU: Hot cache
        self.l2_cache = {}  # CPU: Warm cache
        self.l3_cache = {}  # SSD: Cold cache
    
    def get(self, key):
        """Get KV cache from hierarchy."""
        # Check L1 (GPU)
        if key in self.l1_cache:
            return self.l1_cache[key]
        
        # Check L2 (CPU)
        if key in self.l2_cache:
            kv = self.l2_cache[key].cuda()  # Move to GPU
            self.l1_cache[key] = kv
            return kv
        
        # Check L3 (SSD)
        if key in self.l3_cache:
            kv = torch.load(self.l3_cache[key]).cuda()
            self.l1_cache[key] = kv
            return kv
        
        return None
    
    def put(self, key, kv):
        """Store KV cache in hierarchy."""
        # Always insert to L1
        self.l1_cache[key] = kv
        
        # Evict if L1 is full
        if len(self.l1_cache) > L1_MAX_SIZE:
            # Evict LRU to L2 (CPU)
            evict_key = self._get_lru_key()
            self.l2_cache[evict_key] = self.l1_cache[evict_key].cpu()
            del self.l1_cache[evict_key]

3. Multi-Query Attention (MQA)

Reduce KV cache size by sharing each KV head across multiple query heads:
# Standard attention
num_kv_heads = 32  # Same as num_query_heads
kv_size = 32 * head_dim * 2  # K and V

# Grouped-query attention (GQA)
num_kv_heads = 8   # Fewer than num_query_heads
kv_size = 8 * head_dim * 2   # 4x smaller

# Multi-query attention (MQA)
num_kv_heads = 1   # Single KV head
kv_size = 1 * head_dim * 2   # 32x smaller
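Tying this back to the per-token formula from earlier, the savings are easy to check (illustrative numbers for a 32-layer, head_dim=128, fp16 model):

```python
def kv_bytes(num_kv_heads, num_layers=32, head_dim=128, dtype_bytes=2):
    """Per-token KV cache size: layers x (K and V) x kv_heads x dim x bytes."""
    return num_layers * 2 * num_kv_heads * head_dim * dtype_bytes

mha = kv_bytes(32)  # standard multi-head attention
gqa = kv_bytes(8)   # grouped-query attention
mqa = kv_bytes(1)   # multi-query attention
assert mha // gqa == 4 and mha // mqa == 32
```

The same pool capacity therefore holds 4x (GQA) or 32x (MQA) more tokens, which translates directly into larger batch sizes or longer contexts.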

Monitoring and Debugging

Memory Usage Statistics

def get_memory_stats(self):
    """Get memory usage statistics."""
    return {
        "total_kv_cache": self.pool.size,
        "used_kv_cache": self.pool.size - len(self.pool.free_slots) * self.pool.token_size,
        "free_kv_cache": len(self.pool.free_slots) * self.pool.token_size,
        "cache_hit_rate": self.cache_hits / self.total_requests,
        "radix_tree_nodes": self.radix_cache.total_nodes,
        "gpu_memory_allocated": torch.cuda.memory_allocated(),
        "gpu_memory_reserved": torch.cuda.memory_reserved(),
    }

Visualize Memory Usage

import matplotlib.pyplot as plt

def plot_memory_usage(stats_history):
    """Plot memory usage over time."""
    times = [s["time"] for s in stats_history]
    used = [s["used_kv_cache"] / 1e9 for s in stats_history]  # GB
    free = [s["free_kv_cache"] / 1e9 for s in stats_history]
    
    plt.figure(figsize=(10, 6))
    plt.plot(times, used, label="Used")
    plt.plot(times, free, label="Free")
    plt.xlabel("Time (s)")
    plt.ylabel("Memory (GB)")
    plt.legend()
    plt.title("KV Cache Memory Usage")
    plt.show()

Best Practices

1. Set Appropriate Memory Fraction

# Leave headroom for PyTorch overhead
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B \
  --mem-fraction-static 0.85  # 85% for model + KV cache

2. Enable RadixCache

# Enable prefix caching (enabled by default)
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B
  # RadixCache is on by default

# Disable if not beneficial
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B \
  --disable-radix-cache

3. Use Chunked Prefill

# Prevent large prefills from blocking decode
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B \
  --max-prefill-tokens 16384 \
  --chunked-prefill-size 512
