Quantized KV cache reduces the memory footprint of key-value cache storage by using lower-precision data types (FP8 or FP4) instead of the default model precision (BF16). This enables longer context lengths or more concurrent requests within the same memory budget.

Overview

During autoregressive generation, LLMs cache previously computed key-value pairs to avoid redundant computation. The KV cache typically consumes a significant portion of GPU memory, especially for long sequences. Quantized KV cache is a memory optimization that primarily benefits throughput by allowing more tokens to be cached; depending on the quantization format used, it may introduce some accuracy degradation.
Performance Warning: When quantized KV cache must be dequantized before use in attention operations, performance can be extremely slow if dequantization is not fused with the attention kernel. Always verify that your chosen attention backend supports quantized KV cache. Backends without fused support may experience significant throughput degradation, potentially negating the memory benefits.

Backend Support: Not all attention backends support quantized KV cache. Refer to Attention Backends for compatibility details.

Supported Formats

FP8 Format

OCP (Open Compute Project) specifies two common 8-bit floating point formats:

E5M2

5 exponent bits, 2 mantissa bits
  • Larger dynamic range (±57344.0)
  • Lower precision
  • Better for values with wide range

E4M3

4 exponent bits, 3 mantissa bits
  • Higher precision
  • Smaller dynamic range (±448.0)
  • Recommended for most use cases
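The dynamic ranges quoted above follow directly from the bit layouts. A minimal derivation in Python (E5M2 reserves its top exponent for inf/NaN; the OCP E4M3 "fn" variant keeps the top exponent usable for finite values and reserves only the all-ones mantissa pattern for NaN):

```python
def e5m2_max():
    """Largest finite E5M2 value: the top exponent (11111) is reserved
    for inf/NaN, so the maximum biased exponent is 30."""
    bias = 2 ** (5 - 1) - 1            # 15
    max_exp = 30 - bias                # 15
    max_mantissa = 2 - 2 ** -2         # 1.75 (both mantissa bits set)
    return max_mantissa * 2.0 ** max_exp

def e4m3_max():
    """Largest finite E4M3 ("fn" variant): the top exponent stays usable;
    only exponent=1111 with mantissa=111 encodes NaN, so the largest
    finite value uses mantissa 110 with the top exponent."""
    bias = 2 ** (4 - 1) - 1            # 7
    max_exp = 15 - bias                # 8
    max_mantissa = 1 + 6 / 8           # 1.75 (mantissa 110)
    return max_mantissa * 2.0 ** max_exp

print(e5m2_max())  # 57344.0
print(e4m3_max())  # 448.0
```

The trade-off is visible in the layouts: E5M2's extra exponent bit buys range, while E4M3's extra mantissa bit buys precision within a narrower range.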

FP4 Format (Experimental)

FP4 quantization is currently experimental.
OCP specifies MXFP4 (Microscaling FP4), a 4-bit floating-point format with an E2M1 layout (1 sign bit, 2 exponent bits, 1 mantissa bit):
  • Uses block-based microscaling where tensors are divided into blocks of consecutive elements
  • Each block shares a single 8-bit exponent-only (E8M0) scaling factor
  • OCP specifies blocks of 32 elements; SGLang currently uses blocks of 16 elements
  • Scaling factors are computed dynamically on-the-fly (no pre-quantization required)
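As a rough illustration of the block-based scheme above, the sketch below quantizes one 16-element block to the E2M1 value grid with a shared power-of-two scale. The scale-selection heuristic (snapping the block maximum under the FP4 maximum of 6.0) is one common convention, not SGLang's exact kernel logic:

```python
import math

FP4_E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # representable magnitudes
BLOCK_SIZE = 16  # SGLang's block size (the OCP MX spec uses 32)

def quantize_block(values):
    """Quantize one block: choose a power-of-two scale so the block's
    largest magnitude fits under the FP4 maximum (6.0), then round each
    element to the nearest representable E2M1 value."""
    amax = max(abs(v) for v in values)
    scale = 2.0 ** math.ceil(math.log2(amax / 6.0)) if amax > 0 else 1.0
    codes = [math.copysign(min(FP4_E2M1_GRID, key=lambda g: abs(abs(v) / scale - g)), v)
             for v in values]
    return scale, codes

def dequantize_block(scale, codes):
    """Recover approximate values: multiply stored FP4 values by the block scale."""
    return [scale * c for c in codes]

scale, codes = quantize_block([0.1 * i for i in range(1, BLOCK_SIZE + 1)])
print(scale)  # 0.5 — a power of two, storable as an 8-bit exponent
```

Because the scale is recomputed from each block's maximum at quantization time, no calibration pass or external scaling-factor file is needed, which is exactly why FP4 needs no --quantization-param-path.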

Usage

Enabling Quantized KV Cache

python3 -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-R1-0528 \
  --kv-cache-dtype fp8_e5m2

Scaling Factors

FP8 quantization requires scaling factors to properly quantize and dequantize the KV cache.
Currently, only per-tensor (scalar) scaling factors are supported.
Scaling factors can be:
  1. Loaded from checkpoints: Pre-quantized models (e.g., ModelOpt) may include k_scale and v_scale parameters that are automatically loaded
  2. Provided via JSON: Supply scaling factors via --quantization-param-path
JSON format:
{
  "kv_cache": {
    "dtype": "float8_e4m3fn",
    "scaling_factor": {
      "0": {
        "0": 1.0,
        "1": 1.0
      }
    }
  }
}
Where outer keys are tensor parallel ranks and inner keys are layer indices.
If scaling factors are neither provided nor found in the checkpoint, they default to 1.0, which may cause accuracy issues.
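A sketch of producing such a JSON file for one tensor-parallel rank and two layers. The calibration rule here (scale = observed absolute maximum divided by the E4M3 maximum of 448) is a common convention standing in for a real calibration pass, and the activation lists are made up for illustration:

```python
import json

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def per_tensor_scale(activations):
    # Hypothetical calibration: map the observed absolute maximum onto
    # the E4M3 maximum so the full FP8 range is used.
    return max(abs(a) for a in activations) / E4M3_MAX

# Made-up calibration data: TP rank 0, two layers of KV activations.
calibration = {0: {0: [-3.0, 2.5, 1.0], 1: [0.5, -0.25, 4.0]}}

params = {
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        "scaling_factor": {
            str(rank): {str(layer): per_tensor_scale(acts)
                        for layer, acts in layers.items()}
            for rank, layers in calibration.items()
        },
    }
}

with open("scaling_factors.json", "w") as f:
    json.dump(params, f, indent=2)
```

The resulting file follows the rank-then-layer nesting shown above and can be passed via --quantization-param-path.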
FP4 (MXFP4): Unlike FP8, FP4 quantization handles scaling factors automatically on-the-fly during quantization and dequantization. No pre-quantized models or external scaling factor files are required—the block-based scaling factors are computed dynamically as needed.

Performance Considerations

Memory Savings

Quantized KV cache provides significant memory savings:
| Format | Tokens Supported (vs BF16) |
|--------|----------------------------|
| BF16   | 1.00× (baseline)           |
| FP8    | ~2.00×                     |
| FP4    | ~3.56×                     |
FP4 quantization requires additional memory for its block-based scaling factors, which reduces the effective savings compared to the raw bit-width reduction; FP8's single per-tensor scale adds negligible overhead. The ratios above account for this overhead.
This enables:
  • Longer context lengths within the same memory budget
  • More concurrent requests for improved throughput
  • Better GPU utilization by reducing KV cache memory pressure
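The ratios in the table above follow from bits-per-element arithmetic, using the 16-element blocks and 8-bit block scales described earlier:

```python
# Effective bits per cached element, including scaling-factor overhead.
BF16_BITS = 16

# FP8: 8 bits per element; the single per-tensor scale adds ~nothing.
fp8_bits = 8.0

# FP4: 4 bits per element plus an 8-bit scale shared by each 16-element block.
fp4_bits = 4 + 8 / 16  # 4.5 effective bits

print(BF16_BITS / fp8_bits)            # 2.0
print(round(BF16_BITS / fp4_bits, 2))  # 3.56
```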

Accuracy Impact

FP8 Accuracy

FP8 E4M3 quantization typically introduces minimal accuracy degradation. The impact depends on:
  • Model architecture
  • Sequence length
  • Quantization format (E4M3 generally has better accuracy than E5M2)
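The E4M3-vs-E5M2 gap can be illustrated by enumerating each format's representable values and measuring round-to-nearest error on sample data (an illustrative experiment, not a model benchmark):

```python
def fp_values(exp_bits, man_bits, max_val):
    """Enumerate the non-negative representable values (subnormals and
    normals) of a small float format, capped at its largest finite value."""
    bias = 2 ** (exp_bits - 1) - 1
    vals = set()
    for e in range(2 ** exp_bits):
        for m in range(2 ** man_bits):
            if e == 0:  # subnormals
                v = (m / 2 ** man_bits) * 2.0 ** (1 - bias)
            else:
                v = (1 + m / 2 ** man_bits) * 2.0 ** (e - bias)
            if v <= max_val:
                vals.add(v)
    return sorted(vals)

E4M3 = fp_values(4, 3, 448.0)     # extra mantissa bit -> finer spacing
E5M2 = fp_values(5, 2, 57344.0)   # extra exponent bit -> wider range

def mean_rel_error(grid, samples):
    """Average relative error of round-to-nearest quantization."""
    return sum(abs(min(grid, key=lambda g: abs(g - x)) - x) / x
               for x in samples) / len(samples)

samples = [0.7 + 0.013 * i for i in range(1, 200)]  # well inside both ranges
e4m3_err = mean_rel_error(E4M3, samples)
e5m2_err = mean_rel_error(E5M2, samples)
print(e4m3_err < e5m2_err)  # True: E4M3's extra mantissa bit halves the step size
```

For values that fit within E4M3's ±448 range (typical for KV activations), its finer spacing gives roughly half the rounding error of E5M2, which is why E4M3 is the recommended default.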

FP4 Accuracy

FP4 (MXFP4) quantization provides significant memory savings with varying accuracy impact depending on model size and dataset complexity.

Large Models (200B+ parameters)

On large-scale models, FP4 maintains accuracy close to FP8/BF16, especially on simpler datasets:
| Model | Dataset | BF16 | FP8 E4M3 | FP4 E2M1 |
|-------|---------|------|----------|----------|
| Qwen3-235B-A22B | gsm8k | 0.9168 | 0.9181 | 0.9186 |
| Qwen3-235B-A22B | aime25 | 0.7733 | 0.7333 | 0.6000 |
| Qwen3-235B-A22B | gpqa_diamond | 0.7010 | 0.6899 | 0.6778 |
| DeepSeek-R1-0528 | gsm8k | 0.9157 | 0.9154 | 0.9124 |
| DeepSeek-R1-0528 | aime25 | 0.5067 | 0.4934 | 0.4000 |
| DeepSeek-R1-0528 | gpqa_diamond | 0.7707 | 0.7697 | 0.7273 |
Smaller Models (<200B parameters)

On smaller models, FP4 shows more pronounced accuracy drops, particularly on challenging datasets:
| Model | Dataset | BF16 | FP8 E4M3 | FP4 E2M1 |
|-------|---------|------|----------|----------|
| GPT-OSS-120B | gsm8k | 0.9161 | 0.9163 | 0.9152 |
| GPT-OSS-120B | aime25 | 0.7533 | 0.7667 | 0.3533 |
| GPT-OSS-120B | gpqa_diamond | 0.5081 | 0.5434 | 0.3202 |
Key Observations:
  • Simple datasets (e.g., gsm8k): FP4 maintains accuracy close to FP8/BF16 across model sizes
  • Model size matters: Large models (200B+ parameters) generally tolerate FP4 quantization better than smaller models
  • Context length: Accuracy degradation may be more pronounced in long-context scenarios due to accumulation of quantization error
Evaluate FP4 accuracy on your specific model and workload. Large models on simpler tasks typically show minimal degradation, while smaller models or complex reasoning tasks may require FP8 or BF16 for acceptable accuracy.

Backend Compatibility

Not all attention backends support quantized KV cache. Refer to the support matrix:

MHA Backends

Backend | FP8 KV Cache | FP4 KV Cache
FlashInfer
FA3 (FlashAttention 3)
FA4 (FlashAttention 4)
Triton
Torch Native (SDPA)
TRTLLM MHA
AITER (ROCm)

MLA Backends

Backend | FP8 KV Cache | FP4 KV Cache
FlashInfer MLA
FlashMLA
Cutlass MLA
TRTLLM MLA (Blackwell)
FA3 (FlashAttention 3)
FA4
Backends without native quantized KV cache support will require dequantization before attention operations, which can severely impact performance. Always choose a backend with fused dequantization support.

Examples

DeepSeek-R1 with FP8 KV Cache

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend flashmla \
  --kv-cache-dtype fp8_e4m3 \
  --trust-remote-code

Qwen3-235B with FP4 KV Cache

python3 -m sglang.launch_server \
  --tp 8 \
  --model Qwen/Qwen3-235B-A22B-Instruct-2507 \
  --attention-backend fa4 \
  --kv-cache-dtype fp4_e2m1 \
  --page-size 128 \
  --trust-remote-code

Pre-quantized Model with Custom Scaling Factors

python3 -m sglang.launch_server \
  --model-path your-model \
  --kv-cache-dtype fp8_e4m3 \
  --quantization-param-path scaling_factors.json

Best Practices

Use Pre-quantized Models

Prefer models quantized offline with scaling factors included in the checkpoint for best accuracy.

Choose the Right Format

Use fp8_e4m3 for better accuracy (recommended), fp8_e5m2 for larger dynamic range, or fp4_e2m1 for maximum memory savings (experimental).

Check Backend Compatibility

Verify that your chosen attention backend supports quantized KV cache with fused dequantization.

Evaluate Accuracy

Test FP4/FP8 accuracy on your specific workload before production deployment, especially for complex reasoning tasks.

Troubleshooting

Performance Degradation

If quantized KV cache degrades performance:
  1. Check backend support: Verify your attention backend supports quantized KV cache with fused dequantization
  2. Try different formats: FP8 may perform better than FP4 on some backends
  3. Monitor memory bandwidth: Quantization reduces memory footprint but increases compute

Accuracy Issues

If you observe accuracy degradation:
  1. Verify scaling factors: Ensure scaling factors are properly loaded or provided
  2. Try FP8 E4M3: Switch from FP4 or E5M2 to E4M3 for better accuracy
  3. Evaluate on your dataset: Test on representative samples before full deployment
  4. Consider model size: Smaller models may require higher precision

Missing Scaling Factors

If scaling factors default to 1.0:
  1. Check checkpoint: Verify the model includes k_scale and v_scale parameters
  2. Provide JSON file: Use --quantization-param-path to supply custom scaling factors
  3. Use pre-quantized models: Download models from Unsloth, NVIDIA ModelOpt, or NeuralMagic collections
