This document is based on the original vLLM paper. The current implementation may differ from the description here, but the core concepts remain the same.
Overview
PagedAttention is vLLM’s custom implementation of a multi-head query attention kernel, designed to be compatible with vLLM’s paged KV caches. Key features:
- Key and value caches are stored in separate blocks
- Specially designed memory layout for high performance
- Efficient memory access patterns for GPU threads
Core concepts
Before diving into the implementation, let’s understand some key terminology:

Sequence
A sequence represents a client request. For example, the data pointed to by the query q has a shape of [num_seqs, num_heads, head_size]. Since this is a single-query (decode) kernel, each sequence has only one query token, so num_seqs equals the total number of tokens processed in the batch.
Context
The context consists of the generated tokens from the sequence. For example, ["What", "is", "your"] are context tokens, and the input query token is "name". The model might generate the token "?".
Vec
A vec is a list of elements that are fetched and calculated together:
- Query/Key vec: VEC_SIZE is chosen so that each thread group can fetch 16 bytes at a time
- Value vec: V_VEC_SIZE is chosen so that each thread can fetch 16 bytes at a time

For example, if scalar_t is FP16 (2 bytes) and THREAD_GROUP_SIZE is 2:
- VEC_SIZE = 16 / (2 × 2) = 4
- V_VEC_SIZE = 16 / 2 = 8
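As a rough Python sketch (the helper name is hypothetical, not from the kernel), both vec sizes fall out of the 16-byte fetch width:

```python
def vec_sizes(scalar_bytes, thread_group_size):
    """Derive vec sizes so every 16-byte fetch is fully used.

    VEC_SIZE: elements per thread when a whole thread group
    cooperatively fetches 16 bytes of query/key data.
    V_VEC_SIZE: elements per thread when a single thread fetches
    16 bytes of value data on its own.
    """
    FETCH_BYTES = 16
    vec_size = FETCH_BYTES // (thread_group_size * scalar_bytes)
    v_vec_size = FETCH_BYTES // scalar_bytes
    return vec_size, v_vec_size

# FP16 (2 bytes per element) with THREAD_GROUP_SIZE = 2, as above:
print(vec_sizes(2, 2))  # (4, 8)
```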
Thread group
A small group of threads (THREAD_GROUP_SIZE) that fetches and calculates one query token and one key token at a time. Each thread handles only a portion of the token data.
The total number of elements processed by one thread group is referred to as x.
Example: If the thread group contains 2 threads and head size is 8:
- Thread 0 handles elements at indices 0, 2, 4, 6
- Thread 1 handles elements at indices 1, 3, 5, 7
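The striding in this example can be sketched in a few lines of Python (the function name is illustrative):

```python
def thread_elements(thread_idx, head_size, thread_group_size):
    # Each thread covers the head dimension with a stride equal to the
    # thread group size, starting at its own index within the group.
    return list(range(thread_idx, head_size, thread_group_size))

# Thread group of 2 threads, head size 8, matching the example above:
print(thread_elements(0, 8, 2))  # [0, 2, 4, 6]
print(thread_elements(1, 8, 2))  # [1, 3, 5, 7]
```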
Block
The key and value cache data in vLLM are split into blocks. Each block stores data for a fixed number (BLOCK_SIZE) of tokens at one head.
Each block may contain only a portion of the context tokens.
Example: If block size is 16 and head size is 128, then for one head, one block can store 16 × 128 = 2048 elements.
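The block capacity, and the number of blocks a context needs, can be checked with a short sketch (the context length of 100 is an arbitrary illustration, not from the text):

```python
import math

BLOCK_SIZE, HEAD_SIZE = 16, 128

# Elements stored per block for one head, as in the example above:
elems_per_block = BLOCK_SIZE * HEAD_SIZE
print(elems_per_block)  # 2048

# A context rarely fills its last block exactly, so round up:
context_len = 100
num_blocks = math.ceil(context_len / BLOCK_SIZE)
print(num_blocks)  # 7
```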
Warp
A warp is a group of 32 threads (WARP_SIZE) that execute simultaneously on a streaming multiprocessor (SM).
In the PagedAttention kernel:
- Each warp processes the calculation between one query token and key tokens of one entire block at a time
- May process multiple blocks across multiple iterations

For example, with 4 warps and 6 context blocks:
- Warp 0 handles blocks 0, 4
- Warp 1 handles blocks 1, 5
- Warp 2 handles block 2
- Warp 3 handles block 3
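This warp-to-block assignment is a round-robin stride, which can be sketched as (function name illustrative):

```python
def warp_blocks(warp_idx, num_warps, num_context_blocks):
    # Warps stride over the context blocks round-robin: warp w takes
    # blocks w, w + num_warps, w + 2 * num_warps, ...
    return list(range(warp_idx, num_context_blocks, num_warps))

# 4 warps and 6 context blocks, matching the example above:
for w in range(4):
    print(w, warp_blocks(w, 4, 6))
# 0 [0, 4]
# 1 [1, 5]
# 2 [2]
# 3 [3]
```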
Thread block
A thread block is a group of threads (NUM_THREADS) that can access the same shared memory.
Each thread block:
- Contains multiple warps (NUM_WARPS)
- Processes the calculation between one query token and the key tokens of a whole context
Grid
A grid is a collection of thread blocks; its shape defines how many thread blocks are launched. In the PagedAttention kernel, the grid shape is (num_heads, num_seqs, max_num_partitions).
Therefore, each thread block only handles the calculation for one head, one sequence, and one partition.
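A plain-Python sketch of how the grid shape enumerates work items (the sizes here are arbitrary examples, and the enumeration merely mirrors a CUDA launch):

```python
num_heads, num_seqs, max_num_partitions = 8, 2, 1

# One thread block is launched per (head, sequence, partition) triple:
work_items = [
    (head, seq, part)
    for head in range(num_heads)
    for seq in range(num_seqs)
    for part in range(max_num_partitions)
]
print(len(work_items))  # 16 thread blocks in total
```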
Memory layout
Query memory layout
Each thread defines its own q_ptr, which points to the assigned query token data in global memory.
Example: If VEC_SIZE is 4 and HEAD_SIZE is 128, the q_ptr points to data that contains 128 elements divided into 128 / 4 = 32 vecs.
Shared memory storage: the query data is written into shared memory as q_vecs, with one row per thread in the thread group. For example, if THREAD_GROUP_SIZE is 2:
- Thread 0 handles row 0 vecs
- Thread 1 handles row 1 vecs
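With the example numbers from this section (HEAD_SIZE 128, VEC_SIZE 4, THREAD_GROUP_SIZE 2), the q_vecs geometry can be sketched as (variable names illustrative):

```python
HEAD_SIZE, VEC_SIZE, THREAD_GROUP_SIZE = 128, 4, 2

# One query token's head data divided into vecs, as in the example:
num_vecs = HEAD_SIZE // VEC_SIZE                 # 32 vecs
# q_vecs in shared memory has one row per thread in the group, so
# thread 0 fills row 0 and thread 1 fills row 1:
vecs_per_thread = num_vecs // THREAD_GROUP_SIZE  # 16 vecs per row
print(num_vecs, vecs_per_thread)  # 32 16
```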
Key memory layout
Unlike query processing, each thread group may handle multiple key tokens across multiple iterations. The k_ptr in each thread points to different key tokens at different iterations.

Register memory is used for k_vecs because each one is accessed by only one thread once, whereas q_vecs are accessed by multiple threads multiple times.
Value memory layout
Unlike query and key, there is no thread group concept for value data. Key differences:
- Elements from the same column correspond to the same value token
- Each thread always fetches V_VEC_SIZE elements from the same V_VEC_SIZE tokens at a time
- A single thread retrieves multiple v_vecs from different rows and the same columns through multiple inner iterations
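The value-block geometry implied above can be sketched as follows (rows are head positions, columns are tokens; names are illustrative):

```python
BLOCK_SIZE, HEAD_SIZE, V_VEC_SIZE = 16, 128, 8

# A value block for one head can be viewed as a matrix of
# HEAD_SIZE rows (head positions) by BLOCK_SIZE columns (tokens):
rows, cols = HEAD_SIZE, BLOCK_SIZE

# Each v_vec covers V_VEC_SIZE tokens at one head position, so a
# single row is covered by BLOCK_SIZE / V_VEC_SIZE vecs:
vecs_per_row = cols // V_VEC_SIZE
print(vecs_per_row)  # 2
```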
Computation flow
1. QK computation
The query-key computation follows this pattern: Qk_dot<>::dot produces the full dot-product result between the entire query and key token data.
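The cross-thread-group dot product can be sketched in plain Python standing in for the CUDA template (each "thread" here is one loop index; this is a sketch of the pattern, not vLLM's code):

```python
def qk_dot(q, k, thread_group_size):
    # Each thread computes a partial dot product over its strided
    # share of the head dimension ...
    partials = [
        sum(q[i] * k[i] for i in range(t, len(q), thread_group_size))
        for t in range(thread_group_size)
    ]
    # ... and the partials are reduced across the thread group,
    # yielding the full dot product.
    return sum(partials)

print(qk_dot([1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5], 2))  # 5.0
```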
2. Softmax computation
We calculate the normalized softmax for all QK scores. Steps:
- Reduce across the thread block: get the final reduced qk_max from the whole thread block and broadcast it to all threads

3. LV computation
The logits-value computation follows this pattern: accs is mapped to a head position assigned to the current thread.
Example: If BLOCK_SIZE is 16 and V_VEC_SIZE is 8:
- Each thread fetches 8 value elements for 8 tokens at a time
- Each element is from different tokens at the same head position
- If HEAD_SIZE is 128 and WARP_SIZE is 32, each inner loop processes 32 × 8 = 256 elements
- There are a total of 128 × 16 / 256 = 8 inner iterations per warp for a whole block
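The iteration count above can be verified with a few lines (names illustrative):

```python
HEAD_SIZE, BLOCK_SIZE = 128, 16
WARP_SIZE, V_VEC_SIZE = 32, 8

# Each of the 32 threads in a warp fetches one v_vec per inner loop:
elems_per_iter = WARP_SIZE * V_VEC_SIZE   # 256 elements per iteration
# One block holds HEAD_SIZE x BLOCK_SIZE value elements for one head:
total_elems = HEAD_SIZE * BLOCK_SIZE      # 2048 elements
print(total_elems // elems_per_iter)      # 8 inner iterations per warp
```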
4. Reduction and output
Block-level reduction
Perform a reduction across all warps so that each thread holds the accumulation for its assigned head positions across all context tokens.
Performance optimizations
Memory coalescing
Neighboring threads read neighboring memory locations together, which allows the GPU to combine multiple memory accesses into a single transaction.

Memory hierarchy
- Shared memory: Used for q_vecs, which are accessed by multiple threads multiple times
- Register memory: Used for k_vecs and v_vecs, which are accessed by a single thread once
- Global memory: Used for input/output data
Efficient reduction operations
Warp-level primitives (VLLM_SHFL_XOR_SYNC, VLLM_SHFL_SYNC) are used for efficient reduction operations without requiring shared memory synchronization.
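The butterfly (XOR-shuffle) pattern behind these primitives can be simulated in plain Python: after log2(WARP_SIZE) exchange rounds, every lane holds the full reduction with no shared memory involved (this sketches the general pattern, not vLLM's exact code):

```python
def shfl_xor_reduce(lane_vals, op=lambda a, b: a + b):
    """Simulate a warp-wide reduction via XOR-shuffle exchanges.

    Each round, lane i combines its value with lane (i ^ mask)'s;
    after all rounds every lane holds the reduction over all lanes.
    """
    n = len(lane_vals)  # must be a power of two, e.g. WARP_SIZE = 32
    mask = n // 2
    while mask >= 1:
        lane_vals = [op(lane_vals[i], lane_vals[i ^ mask]) for i in range(n)]
        mask //= 2
    return lane_vals

lanes = shfl_xor_reduce(list(range(32)))            # sum reduction
print(lanes[0], all(v == lanes[0] for v in lanes))  # 496 True
```

The same exchange with a max operator implements the qk_max reduction, and with addition the exp-sum reduction.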
Benefits of PagedAttention
Memory efficiency
Reduces KV cache memory usage by up to 4x through paging
Dynamic allocation
Allows flexible memory sharing between sequences
High throughput
Increases serving throughput by 2-4x compared to baseline
No fragmentation
Eliminates memory fragmentation through block-based allocation
Citation
If you use vLLM’s PagedAttention in your research, please cite the vLLM paper (Kwon et al., "Efficient Memory Management for Large Language Model Serving with PagedAttention", SOSP 2023).

Next steps
Architecture
Learn about vLLM’s overall architecture
Plugin system
Extend vLLM with custom features