SGLang supports a wide variety of attention backends, each with different strengths and tradeoffs. Choosing the right backend is crucial for maximizing performance.
Choose based on your model architecture, hardware platform, and workload characteristics; not all backends are supported on every platform or model architecture.

Automatic Backend Selection

If you don’t specify --attention-backend, SGLang makes a best effort to automatically select the most performant backend based on your hardware and model architecture.

MHA Models (e.g., Llama, Qwen)

  • Hopper (H100, H200): Defaults to fa3 if using CUDA 12.3+ and model configuration is supported
  • Blackwell (B200): Defaults to trtllm_mha, unless using speculative decoding with topk > 1
  • Other Architectures (Ampere, Ada): Defaults to flashinfer if available; otherwise falls back to triton

MLA Models (e.g., DeepSeek V3)

  • Hopper: Defaults to fa3 (requires CUDA 12.3+)
  • Blackwell: Defaults to trtllm_mla
  • Other Architectures: Defaults to triton
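
If one of the defaults above fits your setup, simply omit the flag and let SGLang pick:

```shell
# No --attention-backend flag: SGLang selects the backend from your GPU
# architecture, CUDA version, and model config, per the rules above.
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct
```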

Backend Support Matrix

MHA (Multi-Head Attention) Backends

| Backend | Page Size > 1 (native) | FP8 KV Cache | FP4 KV Cache | Spec topk=1 | Spec topk>1 | Sliding Window | MultiModal |
|---|---|---|---|---|---|---|---|
| FlashInfer | | | | | | | |
| FA3 (FlashAttention 3) | | | | | | | |
| FA4 (FlashAttention 4) | 128 | | | | | | |
| Triton | | | | | | | |
| Torch Native (SDPA) | | | | | | | |
| FlexAttention (PyTorch) | | | | | | | |
| TRTLLM MHA | 16, 32, 64 | | | | | | |
| Dual Chunk FlashAttention | | | | | | | |
| AITER (ROCm) | | | | | | | |
| Wave (ROCm) | | | | | | | |
| Ascend (NPU) | | | | | | | |
| Intel XPU | | | | | | | |
| Intel AMX (CPU) | | | | | | | |

MLA (Multi-Head Latent Attention) Backends

| Backend | Native Page Sizes | FP8 KV Cache | FP4 KV Cache | Chunked Prefix Cache | Spec topk=1 | Spec topk>1 |
|---|---|---|---|---|---|---|
| FlashInfer MLA | 1 | | | | | |
| FlashMLA | 64 | | | | | |
| Cutlass MLA | 128 | | | | | |
| TRTLLM MLA (Blackwell) | 32, 64 | | | | | |
| FA3 (FlashAttention 3) | n/a | | | | | ⚠️ (page_size=1 only) |
| Triton | n/a | | | | | ⚠️ (page_size=1 only) |
| FA4 | 1 | | | | | |
| Ascend MLA (NPU) | 128 | | | | | |
Multimodal attention is selected by --mm-attention-backend. The “MultiModal” column indicates whether a corresponding multimodal implementation exists for that backend family.
Page Size and Prefix Cache: Page size controls how many tokens are grouped into a KV cache block. For the prefix cache to take effect, the number of tokens must fill at least one complete page. For example, if your prompt is only 32 tokens and page_size = 64, it won’t fill a complete page and cannot be matched in the prefix cache. Use page_size = 1 for maximum prefix reuse (token-level matching).
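
The page math in the note above can be checked directly (a throwaway arithmetic sketch, not an SGLang API):

```shell
# Only whole pages can be matched in the prefix cache: round the prompt
# length down to a multiple of page_size to see how many tokens can reuse it.
prompt_tokens=32

page_size=64
matched=$(( prompt_tokens / page_size * page_size ))
echo "page_size=64 -> matched=$matched"   # 0: the 32-token prompt fills no page

page_size=1
matched=$(( prompt_tokens / page_size * page_size ))
echo "page_size=1  -> matched=$matched"   # 32: token-level matching reuses everything
```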

Backend Descriptions

FlashInfer

Best for: General-purpose MHA models on non-Hopper GPUs (A100, A40)
High-performance attention implementation with broad feature support, including FP8 KV cache, speculative decoding, and sliding window attention.
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend flashinfer

FlashAttention 3 (FA3)

Best for: Hopper GPUs (H100, H200, H20)
Default backend on Hopper machines, optimized for the SM90 architecture with excellent performance for both MHA and MLA models.
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend fa3

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-V3 \
  --attention-backend fa3 \
  --trust-remote-code

FlashAttention 4 (FA4)

Best for: Blackwell GPUs (B200) and FP4 KV cache workloads
Supports both prefill and decode on SM90 (Hopper) and SM100 (Blackwell). On Hopper, requires page_size = 128.
FA4 on Hopper (SM90): FA4 decode throughput degrades as sequence length grows because SplitKV is not supported. Relative to FA3 on an H100 at batch size 1: ~10% slower at 2K tokens, ~18% at 4K, ~31% at 8K, and ~49% at 16K. Larger batch sizes narrow the gap. Blackwell (SM100) is unaffected.
python3 -m sglang.launch_server \
  --model-path Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 \
  --attention-backend fa4 \
  --page-size 128 \
  --trust-remote-code

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --prefill-attention-backend fa4 \
  --trust-remote-code

FlashMLA

Best for: MLA models with FP8 KV cache on Hopper
Specialized backend for the MLA architecture with native support for FP8 and FP4 KV cache. Requires page_size = 64.
python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend flashmla \
  --kv-cache-dtype fp8_e4m3 \
  --trust-remote-code

TRTLLM MLA

Best for: Blackwell architecture (B200) with MLA models
Optimized for Blackwell GPUs with excellent performance for MLA models. Supports FP8 and FP4 KV cache.
python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend trtllm_mla \
  --kv-cache-dtype fp8_e4m3 \
  --trust-remote-code

TRTLLM MHA

Best for: Blackwell architecture (B200) with MHA models
Optimized for Blackwell GPUs. Supports page_size of 16, 32, or 64.
python3 -m sglang.launch_server \
  --tp 4 \
  --model Qwen/Qwen3.5-35B-A3B-FP8 \
  --attention-backend trtllm_mha \
  --trust-remote-code

# XQA backend for SM90 and SM120 (H20, H200, 5090)
python3 -m sglang.launch_server \
  --tp 4 \
  --model Qwen/Qwen3.5-35B-A3B-FP8 \
  --decode-attention-backend trtllm_mha \
  --page-size 64 \
  --trust-remote-code

Triton

Best for: Development, debugging, and FP4 KV cache
Flexible Triton-based implementation supporting FP4 KV cache and various advanced features. A good fallback for unsupported configurations.
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend triton

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-V3 \
  --attention-backend triton \
  --trust-remote-code

Cutlass MLA

High-performance MLA backend using CUTLASS kernels. Requires page_size = 128.
python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend cutlass_mla \
  --trust-remote-code

Platform-Specific Backends

AMD ROCm

AITER: Recommended for ROCm platforms
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend aiter
Wave: Alternative ROCm backend
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend wave

Ascend NPU

python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend ascend

Intel XPU

python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend intel_xpu

Other Backends

Torch Native (SDPA): PyTorch’s scaled dot-product attention
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend torch_native
FlexAttention: PyTorch’s FlexAttention API
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend flex_attention
Dual Chunk FlashAttention: For long-context models
python3 -m sglang.launch_server \
  --model Qwen/Qwen2.5-14B-Instruct-1M \
  --attention-backend dual_chunk_flash_attn

GDN Attention Backends

GDN (Gated Delta Network) is a linear attention mechanism with O(n) complexity, used in hybrid models that alternate GDN linear attention layers with standard full attention layers (e.g., Qwen 3.5, Qwen 3 Next, Jet Nemotron, Jet VLM). GDN is not selected via --attention-backend; it is automatically activated when the model architecture requires it. The GDN linear attention layers have their own kernel backends, selected via --linear-attn-backend (default: triton).
| Backend | Decode | Prefill / Extend | Spec Decoding (Target Verify) |
|---|---|---|---|
| Triton (CUDA) | | | |
| Triton (AMD/ROCm) | | | |
| Triton (NPU) | | | |
| Triton (CPU) | | | |
| CuTe DSL (CUDA only) | | | |
CuTe DSL (CUDA only)
Platform Constraints for GDN Models:
  • Blackwell (B200): triton, trtllm_mha, or fa4 only
  • NPU (Ascend): ascend only
  • AMD (ROCm): triton recommended
  • Other CUDA (Hopper, Ampere): auto-selection works; no special constraints
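
A sketch of selecting the GDN kernel backend explicitly via the `--linear-attn-backend` flag described above (the model path is illustrative; any GDN hybrid model works):

```shell
# GDN linear-attention layers follow --linear-attn-backend; the interleaved
# full-attention layers still follow --attention-backend, subject to the
# platform constraints listed above.
python3 -m sglang.launch_server \
  --model Qwen/Qwen3-Next-80B-A3B-Instruct \
  --linear-attn-backend triton \
  --trust-remote-code
```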

Hybrid Attention (Experimental)

You can mix-and-match attention backends for prefill and decode. This is useful when one backend excels at prefill and another excels at decode.
# Example: Prefill with FA4, Decode with TRTLLM MLA (Blackwell)
python3 -m sglang.launch_server \
  --model-path nvidia/DeepSeek-R1-FP4 \
  --tp 8 \
  --attention-backend trtllm_mla \
  --moe-runner-backend flashinfer_trtllm \
  --quantization modelopt_fp4 \
  --prefill-attention-backend fa4

Speculative Decoding with Hybrid Attention

The backend used for draft decoding and target verification depends on --speculative-attention-mode:
  • --speculative-attention-mode decode (recommended): draft/verify use the decode backend
  • --speculative-attention-mode prefill (default): draft/verify use the prefill backend
Constraints:
  • If any attention backend is trtllm_mha, speculative decoding supports only --speculative-eagle-topk 1
  • For paged MHA backends with --page-size > 1 and --speculative-eagle-topk > 1, only flashinfer is supported
  • CUDA Graph: the decode backend is always captured; the prefill backend is captured only when --speculative-attention-mode prefill
If you set only one of --prefill-attention-backend or --decode-attention-backend, the unspecified phase inherits --attention-backend. If both are specified and differ, SGLang automatically enables a hybrid wrapper.
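
The inheritance rule resolves like shell parameter-expansion defaults (an illustrative sketch of how the flags combine, not SGLang code):

```shell
# Unset phase flags fall back to --attention-backend.
attention_backend="fa3"
prefill_attention_backend=""            # not passed on the command line
decode_attention_backend="flashinfer"   # explicitly passed

prefill="${prefill_attention_backend:-$attention_backend}"
decode="${decode_attention_backend:-$attention_backend}"
echo "prefill=$prefill decode=$decode"  # prefill=fa3 decode=flashinfer -> hybrid wrapper
```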

Backend Selection Guide

Hopper GPUs (H100/H200)

Use FA3 for both MHA and MLA models. Best overall performance on SM90 architecture.

Blackwell GPUs (B200)

Use TRTLLM MLA for MLA models and TRTLLM MHA for MHA models. Optimized for SM100 architecture.

Ampere/Ada GPUs (A100/A40)

Use FlashInfer for best compatibility and performance on older architectures.

FP4 KV Cache

Use FA4 on Blackwell, FlashMLA on Hopper for MLA, or Triton as fallback.

FP8 KV Cache

Use FlashMLA or FA3 on Hopper, TRTLLM on Blackwell, FlashInfer on Ampere/Ada.

Long Context

Use Dual Chunk FlashAttention for million-token contexts, or FA3/FlashInfer with sliding window.

Best Practices

  1. Let SGLang auto-select: Unless you have specific requirements, let SGLang automatically choose the backend
  2. Match page size to backend: Check backend requirements for page size (e.g., FA4 requires 128 on Hopper)
  3. Consider KV cache format: Choose backends that support your desired KV cache dtype (FP8/FP4/BF16)
  4. Test on your workload: Different backends may perform differently depending on batch size, sequence length, and model size
  5. Monitor for graph breaks: Some backends work better with CUDA graphs than others
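
For practice 4, SGLang ships a serving benchmark you can point at a running server to compare backends on your own traffic; the flags shown are common ones (verify against `python3 -m sglang.bench_serving --help` for your installed version):

```shell
# Launch the server with the backend under test, then in another terminal:
python3 -m sglang.bench_serving \
  --backend sglang \
  --dataset-name random \
  --num-prompts 200 \
  --random-input-len 1024 \
  --random-output-len 256
```

Repeat with each candidate backend and compare throughput and latency before committing to one.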

See Also