TensorRT-LLM implements highly optimized attention mechanisms for autoregressive models, supporting Multi-Head Attention (MHA), Multi-Query Attention (MQA), and Grouped-Query Attention (GQA) with multiple backend implementations.

Attention Variants

  • MHA (Multi-Head Attention): one KV head per query head. The original Transformer design with maximum expressiveness.
  • MQA (Multi-Query Attention): a single KV head shared across all query heads. Minimal KV cache memory.
  • GQA (Grouped-Query Attention): KV heads divided into groups, each shared by several query heads. Balances memory and model quality.
All three attention variants are described in the Attention Is All You Need, Multi-Query Attention, and Grouped-Query Attention papers.
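The difference between the three variants is just how query heads map onto KV heads. A minimal sketch of that mapping (illustrative only, not TensorRT-LLM code):

```python
def kv_head_index(query_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Map a query head to the KV head it shares. With num_kv_heads ==
    num_q_heads this is MHA, with num_kv_heads == 1 it is MQA, and
    anything in between is GQA."""
    group_size = num_q_heads // num_kv_heads
    return query_head // group_size

assert kv_head_index(5, 32, 32) == 5    # MHA: each query head has its own KV head
assert kv_head_index(5, 32, 8) == 1     # GQA: query heads 4-7 share KV head 1
assert kv_head_index(31, 32, 1) == 0    # MQA: every query head shares KV head 0
```

Fewer KV heads means a proportionally smaller KV cache, which is why MQA/GQA models are cheaper to serve at long sequence lengths.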

Attention Backends

TensorRT-LLM provides three attention backends optimized for different use cases:

TRT-LLM Backend (Default)

The default and most optimized backend for production use:
from tensorrt_llm import LLM

# TRT-LLM backend is the default
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")

# Or explicitly specify
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    attn_backend="trtllm"
)
Features:
  • Flash Attention for context phase
  • Masked MHA with multi-block optimization for generation
  • XQA kernels for MQA/GQA models
  • FP8 input/output and KV cache quantization
  • Fused QKV input support
  • RoPE fusion
  • Paged and contiguous KV cache
Recommended for all production deployments. Offers the best performance and supports all TensorRT-LLM features.

FlashInfer Backend

Performance-optimized backend with FlashInfer library:
from tensorrt_llm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    attn_backend="flashinfer"
)
Features:
  • In-flight batching
  • Paged KV cache
  • FP8 quantization for inputs and KV cache
  • RoPE fusion

Vanilla Backend

Reference implementation for debugging and baseline comparisons:
from tensorrt_llm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    attn_backend="vanilla"
)
Not recommended for production use. Limited optimizations and primarily useful for debugging and validation.

Context Phase Optimizations

Flash Attention

For the context (prefill) phase, TensorRT-LLM uses Flash Attention kernels:
  • Short sequences: Vanilla MHA implementation
  • Long sequences: Flash Attention algorithm (reduces memory from O(N²) to O(N))
  • Fused softmax and attention computation
  • Minimal intermediate tensor materialization
Based on the paper FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness:
  • Tiling-based algorithm
  • IO-aware attention
  • Reduces HBM reads/writes
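The tiling idea can be illustrated with a minimal NumPy sketch of the online-softmax algorithm for a single query vector. The real kernels are fused CUDA implementations, but the math is the same: one pass over key/value blocks, keeping only a running max, running denominator, and running output instead of materializing the full N×N score matrix.

```python
import numpy as np

def naive_attention(q, K, V):
    """Reference: full softmax over all scores at once."""
    s = K @ q
    p = np.exp(s - s.max())
    p /= p.sum()
    return p @ V

def flash_attention(q, K, V, block=4):
    """Tiled attention with online softmax over key/value blocks."""
    m = -np.inf                     # running max of scores
    l = 0.0                         # running softmax denominator
    o = np.zeros(V.shape[1])        # running (unnormalized) output
    for i in range(0, K.shape[0], block):
        s = K[i:i + block] @ q      # scores for this block only
        m_new = max(m, s.max())
        p = np.exp(s - m_new)
        scale = np.exp(m - m_new)   # rescale previous partial results
        l = l * scale + p.sum()
        o = o * scale + p @ V[i:i + block]
        m = m_new
    return o / l

rng = np.random.default_rng(0)
q = rng.standard_normal(8)
K = rng.standard_normal((16, 8))
V = rng.standard_normal((16, 8))
assert np.allclose(flash_attention(q, K, V), naive_attention(q, K, V))
```

Because each block's partial results are rescaled as a larger running max is found, the tiled version is exact, not an approximation.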

FP8 Context FMHA

When FP8 quantization is enabled, context attention is further accelerated:
from tensorrt_llm import LLM
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

quant_config = QuantConfig(
    quant_algo=QuantAlgo.FP8,
    kv_cache_quant_algo=QuantAlgo.FP8
)

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    quant_config=quant_config,
    # Paged context FMHA (use_paged_context_fmha) is enabled by default
)
FP8 Paged Context FMHA is supported on Ada, Hopper, and Blackwell GPUs.

Generation Phase Optimizations

Masked Multi-Head Attention

The generation (decode) phase uses a specialized masked MHA kernel:
  • On-the-fly QKV bias addition
  • Fused RoPE application
  • Dequantization/quantization support
  • Multi-block mode for low GPU occupancy

Multi-Block Mode

When batch size and number of heads are small, multi-block mode distributes work across multiple CUDA thread blocks:
# Multi-block is always enabled (automatic heuristic)
# Typically beneficial when: batch_size * num_heads < GPU_multiprocessor_count
Multi-block mode is triggered automatically by internal heuristics. It activates when sequences are long enough and GPU occupancy is low.
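The actual heuristic is internal; the sketch below only illustrates its general shape, and every name and threshold in it is an assumption, not the real implementation:

```python
def use_multi_block(batch_size, num_kv_heads, num_sms, seq_len,
                    min_seq_len=1024):
    """Hypothetical sketch: split each sequence's attention across extra
    CUDA thread blocks only when the normal launch (roughly one block per
    batch * KV head) would leave most streaming multiprocessors idle, and
    the sequence is long enough to amortize the final cross-block
    reduction."""
    return batch_size * num_kv_heads < num_sms and seq_len >= min_seq_len

# Small batch, few KV heads, long sequence: multi-block helps
assert use_multi_block(batch_size=1, num_kv_heads=8, num_sms=108, seq_len=4096)
# Large batch already saturates the GPU: stick with one block per head
assert not use_multi_block(batch_size=64, num_kv_heads=32, num_sms=108, seq_len=4096)
```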

XQA Optimization

XQA (eXtended Query Attention) is a specialized kernel for MQA/GQA in the generation phase:
import os

# Force XQA kernel when supported (optional)
os.environ['TRTLLM_FORCE_XQA'] = '1'
Support Matrix:
  • FP16 / BF16 compute data type
  • FP16 / BF16 / FP8 / INT8 KV cache data type
  • Paged KV cache (8 / 16 / 32 / 64 / 128 tokens per block)
XQA is enabled by default with automatic heuristics. Set TRTLLM_FORCE_XQA=1 to always use XQA when the model configuration is supported.

In-Flight Batching

TensorRT-LLM supports continuous batching of requests:
  • Context-phase sequences can be batched with generation-phase sequences
  • Reduces latency and improves GPU utilization
  • Requires packed (non-padded) input tensors
from tensorrt_llm import LLM

# In-flight batching is enabled by default
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")

# Submit multiple requests; new prompts enter the context phase while
# earlier requests continue in the generation phase
prompts = [
    "What is AI?",
    "Explain quantum computing",
]

outputs = llm.generate(prompts)
Important: Sequences in context phase must appear before sequences in generation phase in the input tensor.

Chunked Context

Long contexts can be split into chunks for better batching:
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Chunked context requires paged KV cache
kv_cache_config = KvCacheConfig(
    # Paged KV cache is enabled by default
)

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    kv_cache_config=kv_cache_config
)
Benefits:
  • Context chunks batch with generation tokens
  • Increases total throughput
  • Removes constraints on input length
  • Better GPU utilization
Chunk size (except the last chunk) must be an integer multiple of the KV cache block size.
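The chunking constraint can be illustrated with a small helper. This function and the 64-token block size are hypothetical, for illustration only, not part of the TensorRT-LLM API:

```python
def chunk_context(context_len, chunk_size, kv_block_size=64):
    """Split a context of context_len tokens into chunks. Every chunk
    except possibly the last must be an integer multiple of the KV cache
    block size, so chunk boundaries line up with cache blocks."""
    if chunk_size % kv_block_size != 0:
        raise ValueError("chunk size must be a multiple of the KV block size")
    return [min(chunk_size, context_len - start)
            for start in range(0, context_len, chunk_size)]

# A 10,000-token context with 2048-token chunks: four full chunks plus
# a final partial chunk of 1808 tokens
assert chunk_context(10000, 2048) == [2048, 2048, 2048, 2048, 1808]
```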

Advanced Features

Rotary Positional Embedding (RoPE)

RoPE is fused into the attention operator:
# RoPE is automatically applied based on model config
# Supports both GPT-NeoX and GPT-J variants
Supported RoPE Types:
  • rope_gpt_neox: Standard GPT-NeoX RoPE
  • rope_gptj: GPT-J variant
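For reference, the GPT-NeoX variant rotates pairs taken from the two halves of each head vector by position-dependent angles. A minimal NumPy sketch of the math (illustrative, not the fused kernel; the GPT-J variant differs only in pairing adjacent elements instead of halves):

```python
import numpy as np

def rope_gpt_neox(x, pos, base=10000.0):
    """Apply GPT-NeoX-style RoPE to a single head vector x at position pos.
    Pairs (x[i], x[i + d/2]) are rotated by angle pos * base^(-2i/d)."""
    d = x.shape[-1]
    inv_freq = base ** (-np.arange(0, d // 2) * 2.0 / d)
    theta = pos * inv_freq
    x1, x2 = x[: d // 2], x[d // 2:]
    return np.concatenate([x1 * np.cos(theta) - x2 * np.sin(theta),
                           x1 * np.sin(theta) + x2 * np.cos(theta)])

v = np.arange(8, dtype=float)
# Position 0 means zero rotation: the vector is unchanged
assert np.allclose(rope_gpt_neox(v, pos=0), v)
# RoPE is a pure rotation, so it preserves the vector's norm
assert np.isclose(np.linalg.norm(rope_gpt_neox(v, pos=7)), np.linalg.norm(v))
```

Fusing this rotation into the attention kernel avoids a separate elementwise pass over Q and K.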

ALiBi (Attention with Linear Biases)

ALiBi slopes are computed on-the-fly:
# ALiBi is automatically enabled for models that use it (e.g., MPT, BLOOM)
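The per-head slopes follow the geometric sequence from the ALiBi paper. A sketch for power-of-two head counts (the paper interpolates additional slopes for other head counts, which is omitted here):

```python
def alibi_slopes(num_heads):
    """Per-head ALiBi slopes for a power-of-two number of heads: a
    geometric sequence starting at 2^(-8/n) with the same value as its
    ratio. Head i penalizes attention to a token at distance d by
    slope_i * d, so later heads attend more broadly."""
    start = 2.0 ** (-8.0 / num_heads)
    return [start ** (i + 1) for i in range(num_heads)]

assert alibi_slopes(8) == [0.5, 0.25, 0.125, 0.0625,
                           0.03125, 0.015625, 0.0078125, 0.00390625]
```

Because the slopes depend only on the head index, the kernel can compute them on the fly instead of loading a bias tensor.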

Cross Attention

Support for encoder-decoder models:
# Cross attention is automatically used in encoder-decoder architectures
# (e.g., T5, BART, Whisper)

Sliding Window Attention

Limited attention windows with cyclic KV cache:
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

kv_cache_config = KvCacheConfig(
    max_attention_window=[2048]  # Only attend to last 2048 tokens
)

llm = LLM(
    model="mistralai/Mistral-7B-v0.1",
    kv_cache_config=kv_cache_config
)
When the input length exceeds the configured attention window size, sliding window attention is automatically activated in the context phase.
Beam search is supported with cache indirection:
from tensorrt_llm.sampling_params import SamplingParams

sampling_params = SamplingParams(
    beam_width=4,
    num_return_sequences=4
)

outputs = llm.generate(prompts, sampling_params)
The cache_indirection tensor (shape [batch_size, beam_width, max_seqlen]) tracks which beam path to read KV cache from at each token position.
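The lookup semantics can be sketched in NumPy. The shapes and names below are illustrative of the indirection idea only, not the engine's internal cache layout:

```python
import numpy as np

def gather_beam_kv(kv_cache, cache_indirection, batch, beam, step):
    """Resolve which beam's lane holds the KV entry this beam should read
    at a given token position, then fetch it.
    kv_cache:          [batch, beam_width, max_seqlen, head_dim]
    cache_indirection: [batch, beam_width, max_seqlen] of source beam ids."""
    src_beam = cache_indirection[batch, beam, step]
    return kv_cache[batch, src_beam, step]

# Toy cache: batch=1, beam_width=2, max_seqlen=3, head_dim=4
kv = np.arange(24, dtype=float).reshape(1, 2, 3, 4)
# Beam 1's surviving path took token 0 from beam 0, tokens 1-2 from itself
indir = np.array([[[0, 0, 0],
                   [0, 1, 1]]])
assert np.array_equal(gather_beam_kv(kv, indir, 0, 1, 0), kv[0, 0, 0])
assert np.array_equal(gather_beam_kv(kv, indir, 0, 1, 1), kv[0, 1, 1])
```

Indirection lets beams share ancestors' KV entries by index instead of physically copying cache blocks on every beam swap.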

Performance Tuning

  • TRT-LLM backend: Best overall performance, recommended for production
  • FlashInfer backend: Good alternative, may be faster for specific workloads
  • Vanilla backend: Only for debugging and validation
  • Use paged KV cache for better memory efficiency
  • Enable FP8 KV cache on Hopper+ GPUs (2x memory reduction)
  • Configure max_attention_window for models with limited attention
  • Enable host offloading for long-running sessions
  • Enable chunked context for very long inputs
  • Use FP8 quantization for context FMHA (Ada, Hopper, Blackwell)
  • Batch context requests together when possible
  • XQA kernels automatically optimize MQA/GQA models
  • Multi-block mode improves performance at low batch sizes
  • Use in-flight batching to mix context and generation phases

Code Example: Custom Attention Configuration

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo
from tensorrt_llm.sampling_params import SamplingParams

# Configure FP8 quantization
quant_config = QuantConfig(
    quant_algo=QuantAlgo.FP8,
    kv_cache_quant_algo=QuantAlgo.FP8
)

# Configure KV cache with sliding window
kv_cache_config = KvCacheConfig(
    dtype='fp8',                      # FP8 KV cache
    free_gpu_memory_fraction=0.9,     # Use 90% of free GPU memory
    max_attention_window=[4096],      # 4K sliding window
    enable_block_reuse=True,          # Cross-request reuse
    host_cache_size=2*1024**3         # 2GB host offload
)

# Initialize with optimized attention
llm = LLM(
    model="mistralai/Mistral-7B-v0.1",
    attn_backend="trtllm",            # Use TRT-LLM backend
    quant_config=quant_config,
    kv_cache_config=kv_cache_config
)

outputs = llm.generate(
    "Explain the theory of relativity in simple terms.",
    SamplingParams(max_tokens=500)
)

Additional Resources

  • Flash Attention Paper: original Flash Attention algorithm
  • Flash Attention 2 Paper: improved Flash Attention with better parallelism
  • GQA Paper: Grouped-Query Attention for efficient inference
  • KV Cache Documentation: detailed KV cache configuration guide
