HiCache (Hierarchical KV Cache) extends SGLang’s RadixAttention with a three-tier caching system that dramatically improves performance for long-context and multi-turn conversation scenarios. By intelligently managing KV caches across GPU memory (L1), host memory (L2), and distributed storage (L3), HiCache addresses the fundamental capacity bottleneck that limits cache hit rates in conventional systems.

Why HiCache?

In large language model inference, the prefill phase is often time-consuming: input sequences must first be converted into Key-Value cache (KV cache) for subsequent decoding. When multiple requests share the same prefix, the KV cache for that prefix is identical. By caching and reusing these shared KV caches, redundant computation can be avoided. SGLang’s RadixAttention leverages idle GPU memory to cache and reuse prefix KV caches. HiCache extends this idea to host memory and distributed storage, inspired by the classic three-level cache design of modern CPUs:
  • L1 Cache: GPU memory (fast, private per instance)
  • L2 Cache: Host memory (larger capacity, private per instance)
  • L3 Cache: Distributed storage (shared across cluster)
This hierarchy enables HiCache to fully exploit idle storage space while integrating distributed cache systems such as Mooncake, 3FS, NIXL, and AIBrix KVCache for global KV cache storage and scheduling. For detailed benchmark results, see the HiCache blog post.

System Architecture

HiRadixTree: Metadata Organization

HiCache builds upon the RadixTree structure introduced in RadixAttention, where each node corresponds to the KV cache of a consecutive span of tokens in GPU memory. HiRadixTree extends this design so that each node:
  • Corresponds to the KV cache of a span of consecutive tokens
  • Records where that KV cache is stored (L1 GPU memory, L2 CPU memory, L3 storage, or multiple tiers at once)
  • Maintains precise metadata, including exact storage addresses, for local storage (L1/L2)
  • Queries the backend in real time to retrieve metadata on demand for L3 storage
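The node structure above can be sketched as follows. This is an illustrative data model, not SGLang's actual implementation: the class and field names (`HiRadixNode`, `Tier`, `l1_address`, and so on) are hypothetical.

```python
from dataclasses import dataclass, field
from enum import Flag, auto
from typing import Dict, List, Optional

class Tier(Flag):
    """Which cache tiers currently hold a node's KV data (may be several)."""
    NONE = 0
    L1_GPU = auto()
    L2_HOST = auto()
    L3_STORAGE = auto()

@dataclass
class HiRadixNode:
    """One span of consecutive tokens and where its KV cache lives."""
    token_ids: List[int]                 # the token span this node covers
    tiers: Tier = Tier.NONE              # a node can reside in multiple tiers
    l1_address: Optional[int] = None     # exact device address (L1 metadata)
    l2_address: Optional[int] = None     # exact host address (L2 metadata)
    children: Dict[int, "HiRadixNode"] = field(default_factory=dict)
    # Note: no L3 location field. HiCache queries the storage backend
    # on demand instead of tracking remote metadata locally.

# A span cached in both GPU and host memory, but not yet backed up to L3:
node = HiRadixNode(token_ids=[1, 2, 3],
                   tiers=Tier.L1_GPU | Tier.L2_HOST,
                   l1_address=0x1000)
```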

Three-Phase Workflow

HiCache operates through three key phases:

1. Local Match

When a new request arrives, HiCache first searches local L1 and L2 caches for matching KV caches:
  • Traverses the HiRadixTree from the root node
  • Matches the token sequence prefix at page granularity (when page_size > 1)
  • Automatically splits nodes for exact boundaries when matches terminate mid-node
  • Returns a continuous prefix with L1 portion followed by L2 portion
  • Extremely fast since no data copying is required
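The matching logic above can be sketched in a few lines. This is a simplified, single-tier model for illustration (the real HiRadixTree tracks L1/L2 locations per node and splits nodes in place); the names `Node`, `match_span`, and `local_match` are hypothetical.

```python
class Node:
    """Toy radix-tree node: a token span plus children keyed by first token."""
    def __init__(self, token_ids):
        self.token_ids = token_ids
        self.children = {}

def match_span(tokens, span, page_size):
    """Longest common prefix of `tokens` and `span`, rounded down to whole pages."""
    n = 0
    for a, b in zip(tokens, span):
        if a != b:
            break
        n += 1
    return (n // page_size) * page_size

def local_match(tokens, root, page_size=1):
    """Traverse from the root, matching at page granularity. A match that
    ends mid-node stops here; real HiCache splits the node at that exact
    boundary. Only metadata is touched, so no KV data is copied."""
    matched, node = 0, root
    while tokens:
        child = node.children.get(tokens[0])
        if child is None:
            break
        n = match_span(tokens, child.token_ids, page_size)
        matched += n
        if n < len(child.token_ids):
            break  # partial match inside this node
        tokens, node = tokens[n:], child
    return matched
```

For example, with cached spans `[1, 2, 3, 4]` and `[5, 6]` chained under the root, the request `[1, 2, 3, 4, 5, 6, 7]` matches 6 tokens at `page_size=2`, but only 4 at `page_size=4` because the trailing partial page cannot be reused.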

2. Prefetch from L3

For parts not found in L1 or L2, HiCache queries L3 storage and prefetches data proactively.
Trigger Conditions:
  • L3 hit length exceeds threshold (default: 256 tokens, configurable)
Prefetch Strategies:
  • best_effort: Terminates immediately when GPU can execute prefill (minimal latency)
  • wait_complete: Waits for all prefetch operations (highest cache hit rate)
  • timeout: Balances latency and cache hit rate with configurable timeout
For the timeout strategy, the timeout is computed as:
timeout = prefetch_timeout_base + prefetch_timeout_per_ki_token * num_token_to_fetch / 1024
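As a quick sanity check of the formula, here is a direct transcription. The parameter values in the example are illustrative only; the actual base and per-KiToken values come from server configuration.

```python
def prefetch_timeout(num_tokens_to_fetch, base, per_ki_token):
    """Timeout (seconds) for the `timeout` prefetch policy: a fixed base
    plus a per-1024-token allowance scaled by how much must be fetched."""
    return base + per_ki_token * num_tokens_to_fetch / 1024
```

For instance, fetching 2048 tokens with a base of 1.0 s and 0.5 s per 1024 tokens yields a 2.0 s timeout.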

3. Write-back

After prefill computation, newly generated KV caches are written back to L2 and L3.
Write-back Policies:
  • write_through: Immediately writes to all tiers (strongest caching benefits when bandwidth is sufficient)
  • write_through_selective: Writes only hot data exceeding access frequency threshold (reduces I/O overhead)
  • write_back: Writes to next level only on eviction (suitable for limited storage capacity)
Cross-instance Sharing: Data written from L2 to L3 is shared across all SGLang instances in the cluster, significantly improving cache hit rates within the same memory budget.

HiCache Workflow
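The three policies can be summarized as a small decision rule. This is a conceptual sketch: the class name, the `hit_threshold` knob, and the hit-count mechanics for `write_through_selective` are assumptions, not SGLang's actual interface.

```python
from collections import Counter

class WriteBackPolicy:
    """Illustrative dispatch over the three write-back policies."""
    def __init__(self, policy="write_through", hit_threshold=2):
        self.policy = policy
        self.hit_threshold = hit_threshold  # assumed knob for selective mode
        self.hits = Counter()

    def should_write_through(self, key):
        """Decide at prefill time whether to push this KV span to L2/L3."""
        if self.policy == "write_through":
            return True  # always back up (strongest caching benefit)
        if self.policy == "write_through_selective":
            self.hits[key] += 1  # only hot data crosses the threshold
            return self.hits[key] >= self.hit_threshold
        return False  # write_back: deferred until eviction
```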

Performance Optimizations

Zero-Copy Data Transfers

HiCache supports passing memory addresses and sizes directly when transferring data from L2 to L3, minimizing data copies and improving performance.

Optimized Memory Layouts

HiCache supports multiple memory layouts for L2 host memory:
  • layer_first: Compatible with GPU computation kernels (default for GPU memory)
  • page_first: All KV cache data for the same page in contiguous memory, enabling zero-copy transfers to L3
  • page_first_direct: Groups all tokens of a given layer within a page, allowing aggregated page-layer transfers from L2 to GPU
HiCache Memory Layout
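The practical difference between `layer_first` and `page_first` is just where a page's data lands in the flat host buffer. The offset functions below are a simplified illustration (uniform page size, no alignment), not the actual pool arithmetic.

```python
def layer_first_offset(layer, page, num_layers, num_pages, page_bytes):
    """layer_first: all pages of one layer are contiguous (kernel-friendly),
    so one page's data is scattered across the buffer, one chunk per layer."""
    return layer * num_pages * page_bytes + page * page_bytes

def page_first_offset(layer, page, num_layers, num_pages, page_bytes):
    """page_first: all layers of one page are contiguous, so a whole page
    can be shipped to L3 in a single zero-copy transfer."""
    return page * num_layers * page_bytes + layer * page_bytes
```

With 2 layers, 3 pages, and 10-byte pages, consecutive layers of the same page sit 10 bytes apart under `page_first` but 30 bytes apart under `layer_first`, which is why `page_first` enables single-transfer page movement.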

CPU-to-GPU Transfer Optimizations

  • Compute-Transfer Overlap: During prefill, concurrently loads layer N+1 while computing layer N
  • GPU-assisted I/O Kernels: Custom kernels achieve up to 3× higher transfer speed compared to baseline cudaMemcpyAsync

Multi-Rank Synchronization

During multi-GPU parallel computation (e.g., tensor parallelism), HiCache uses all_reduce operations to ensure consistent states across ranks:
  • all_reduce(op=min) ensures all ranks obtain same L3 hit count
  • Guarantees consensus on prefix length of successfully retrieved KV cache
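The consensus rule is simply a minimum across ranks. In a real deployment this is a collective `all_reduce` with a min op; the single-process sketch below only illustrates why min is the right reduction.

```python
def agree_on_hit_length(per_rank_hit_tokens):
    """Each TP rank may retrieve a different amount of KV data from L3.
    Taking the minimum (what all_reduce(op=min) computes collectively)
    yields the longest prefix that every rank is guaranteed to hold."""
    return min(per_rank_hit_tokens)
```

For example, if four ranks report hits of 1024, 768, 1024, and 896 tokens, all ranks agree to reuse only the first 768.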

MLA Optimization

For MLA (Multi-head Latent Attention) models under multi-TP, all ranks hold identical KV data. HiCache optimizes write-back by having only one rank initiate the operation, preventing redundant storage.

Configuration Parameters

Core Parameters

--enable-hierarchical-cache
boolean
Enable hierarchical cache functionality. Required to use HiCache.
--hicache-ratio
float
Ratio of host KV cache memory pool size to device pool size. For example, a value of 2 means the host memory pool is twice as large as the device memory pool. Must be greater than 1.
--hicache-size
float
Size of host KV cache memory pool in gigabytes per rank. Overrides --hicache-ratio if set. For example, --hicache-size 30 allocates 30GB (1GB = 1e9 bytes) per rank. With 8 ranks, total memory is 240GB.
In general, a larger HiCache size leads to higher cache hit rate, which improves prefill performance. However, the relationship is not linear. Once most reusable KV data is cached, further increases yield marginal gains. Set based on workload characteristics.
--page-size
integer
Number of tokens per page. Determines granularity of KV cache storage and retrieval. Larger pages reduce metadata overhead and improve I/O efficiency but may lower cache hit rate when only part of a page matches. For long common prefixes, larger pages improve performance; for diverse prefixes, smaller pages are better.
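The hit-rate cost of larger pages comes from rounding: only whole pages can be reused. A minimal sketch of that effect (function name is illustrative):

```python
def usable_prefix_tokens(matched_tokens, page_size):
    """Cache reuse is page-granular: a partially matched trailing page
    cannot be reused, so the hit length rounds down to a page multiple."""
    return (matched_tokens // page_size) * page_size
```

A 1000-token prefix match yields 1000 reusable tokens at `--page-size 1` but only 960 at `--page-size 64`; the larger page trades that loss for lower metadata overhead and better I/O efficiency.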

Prefetch Configuration

--hicache-storage-prefetch-policy
string
default:"timeout"
Controls when prefetching from storage should stop. Options:
  • best_effort: Prefetch without blocking (minimal latency)
  • wait_complete: Wait for prefetch completion (highest hit rate)
  • timeout: Balance latency and hit rate with timeout (recommended for production)

Write Policy

--hicache-write-policy
string
default:"write_through"
Controls how data is written from faster to slower memory tiers. Options:
  • write_through: Immediately writes to all tiers (strongest caching benefits)
  • write_through_selective: Uses hit-count tracking to back up only frequently accessed data
  • write_back: Writes to slower tiers only on eviction (reduces I/O load)

I/O Backend

--hicache-io-backend
string
default:"kernel"
Choose I/O backend for KV cache transfer between CPU and GPU. Options:
  • direct: Standard CUDA memory copy operations
  • kernel: GPU-assisted I/O kernels (recommended for better performance)

Memory Layout

--hicache-mem-layout
string
default:"layer_first"
Memory layout for the host memory pool. Options:
  • layer_first: Compatible with GPU computation kernels
  • page_first: Optimized for I/O efficiency with zero-copy (only compatible with kernel backend)
  • page_first_direct: Aggregated page-layer transfers (compatible with direct backend and FA3)

Storage Backend

--hicache-storage-backend
string
Choose the storage backend for the L3 tier. Built-in options:
  • file: Simple file-based storage for demonstration
  • mooncake: High-performance RDMA-based caching system
  • hf3fs: Kubernetes-native distributed storage
  • nixl: Unified API for various storage plugins
  • aibrix: Production-ready KVCache offloading framework
  • dynamic: Custom backend loaded dynamically
--hicache-storage-backend-extra-config
string
Extra configuration for storage backend. Can be either:
  • A JSON string: '{"prefetch_threshold":512}'
  • A config file path (prepend with @): "@config.toml"
--enable-lmcache
boolean
Use LMCache as an alternative hierarchical cache solution.

Storage Backends

HiCache provides unified interfaces for various L3 storage backends:

Mooncake

High-performance caching system leveraging RDMA and multi-NIC resources for zero-copy, ultra-fast data transfers.
export MOONCAKE_TE_META_DATA_SERVER="http://127.0.0.1:8080/metadata"
export MOONCAKE_GLOBAL_SEGMENT_SIZE=816043786240
export MOONCAKE_PROTOCOL="rdma"
export MOONCAKE_DEVICE="$DEVICE_LIST"
export MOONCAKE_MASTER=127.0.0.1:50051

python3 -m sglang.launch_server \
  --model-path $MODEL_PATH \
  --tp 8 \
  --page-size 64 \
  --enable-hierarchical-cache \
  --hicache-ratio 2 \
  --hicache-mem-layout page_first_direct \
  --hicache-io-backend direct \
  --hicache-storage-backend mooncake \
  --hicache-write-policy write_through \
  --hicache-storage-prefetch-policy timeout

HF3FS

Kubernetes-native distributed storage solution with operator-based deployment.
python3 -m sglang.launch_server \
  --model-path /xxx/DeepSeek-R1/ \
  --tp 8 \
  --page-size 64 \
  --mem-fraction-static 0.85 \
  --enable-hierarchical-cache \
  --hicache-ratio 2 \
  --hicache-mem-layout page_first_direct \
  --hicache-io-backend direct \
  --hicache-write-policy write_through \
  --hicache-storage-backend hf3fs \
  --hicache-storage-prefetch-policy wait_complete

NIXL

Provides unified API for accessing various storage plugins including DeepSeek’s 3FS, GPU Direct Storage (GDS), and Amazon S3-compatible object storage.

AIBrix KVCache

Production-ready KVCache offloading framework enabling efficient memory tiering and low-overhead cross-engine reuse.

LMCache

An efficient KV cache layer for enterprise-scale LLM inference that serves as an alternative to HiCache.

Runtime Attach/Detach

HiCache supports dynamically attaching/detaching the L3 storage backend at runtime without restarting the server.
Attach/detach operations require the service to be idle (no running or queued requests). Operations will fail fast (HTTP 400) if this condition is not met.

Query Current Backend

curl -s http://127.0.0.1:30000/hicache/storage-backend

Attach Storage Backend

curl -s -X PUT http://127.0.0.1:30000/hicache/storage-backend \
  -H 'Content-Type: application/json' \
  -d '{
    "hicache_storage_backend": "mooncake",
    "hicache_storage_backend_extra_config_json": "{\"master_server_address\":\"127.0.0.1:50051\"}",
    "hicache_storage_prefetch_policy": "timeout"
  }'

Detach Storage Backend

curl -s -X DELETE http://127.0.0.1:30000/hicache/storage-backend
Detach only stops using the L3 storage backend and stops prefetch/backup threads. It does not automatically delete data stored in remote backends.

Integration with PD Disaggregation

HiCache works seamlessly with PD (Prefill-Decode) Disaggregation deployment mode:
  1. Prefill-only HiCache: Enable HiCache only on Prefill nodes for KV cache sharing among Prefill instances
  2. Full HiCache with async offloading: Enable HiCache on Prefill nodes and async offloading on Decode nodes for multi-turn dialogue scenarios
# Prefill node with HiCache
python3 -m sglang.launch_server \
  --model-path /xxx/DeepSeek-R1/ \
  --tp 8 \
  --enable-hierarchical-cache \
  --hicache-storage-backend hf3fs \
  --disaggregation-mode prefill \
  --disaggregation-transfer-backend mooncake

# Decode node with async offloading
python3 -m sglang.launch_server \
  --model-path /xxx/DeepSeek-R1/ \
  --tp 8 \
  --hicache-storage-backend hf3fs \
  --disaggregation-decode-enable-offload-kvcache \
  --disaggregation-mode decode \
  --disaggregation-transfer-backend mooncake

Heterogeneous TP Support

HiCache storage supports cross-cluster KV reuse when different deployments use different TP sizes (e.g., tp=4 and tp=8) sharing the same storage backend namespace.
# Example: heterogeneous TP = {4, 8}, so LCM = 8
--hicache-storage-backend-extra-config '{"tp_lcm_size": 8}'
Set tp_lcm_size to the least common multiple (LCM) of all TP sizes that will share the same HiCache storage. For MHA models with Mooncake and page_head layout, HiCache will split head shards based on tp_lcm_size to make keys reusable across heterogeneous TP deployments.
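Computing the value is a straightforward LCM; the helper name below is illustrative.

```python
from math import lcm

def shared_tp_lcm(tp_sizes):
    """tp_lcm_size should be the least common multiple of every TP size
    that shares the same HiCache storage namespace."""
    return lcm(*tp_sizes)
```

For deployments with `tp=4` and `tp=8`, this gives 8, matching the example above; mixing `tp=2`, `tp=3`, and `tp=8` would require `tp_lcm_size=24`.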

Custom Storage Backend Integration

To integrate a custom storage backend:
  1. Implement three core methods:
    • get(key): Retrieve value by key
    • exists(key): Check key existence
    • set(key, value): Store key-value pair
  2. Use dynamic loading:
python3 -m sglang.launch_server \
  --model-path your-model \
  --enable-hierarchical-cache \
  --hicache-storage-backend dynamic \
  --hicache-storage-backend-extra-config '{"backend_name":"custom_backend_name", "module_path": "your_module_path", "class_name": "YourHiCacheClassName"}'
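Step 1 above could look like the sketch below: an in-memory stand-in implementing the three required methods. This is a toy for illustration only; a real backend would talk to remote storage, and the actual SGLang base interface may require additional methods or signatures beyond these three.

```python
class InMemoryHiCacheBackend:
    """Minimal custom L3 backend sketch with the three core methods.
    Pass its module path and class name via the dynamic-loading config."""

    def __init__(self, *args, **kwargs):
        # Accept arbitrary config; a real backend would open connections here.
        self._store = {}

    def get(self, key):
        """Retrieve value by key; None if absent."""
        return self._store.get(key)

    def exists(self, key):
        """Check key existence without fetching data."""
        return key in self._store

    def set(self, key, value):
        """Store a key-value pair."""
        self._store[key] = value
```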

Best Practices

Memory Configuration

Set --hicache-ratio or --hicache-size based on workload. Larger values improve hit rate but with diminishing returns after hot data is cached.

Layout Selection

Use page_first_direct with direct backend for best compatibility. Use page_first with kernel backend for maximum I/O efficiency.

Prefetch Policy

Use timeout policy in production with tuned timeout parameters to meet SLOs while maximizing cache hits.

Write Policy

Use write_through when bandwidth is sufficient. Use write_through_selective to reduce I/O overhead for mixed workloads.

See Also