vLLM’s memory management system is built around PagedAttention, a novel attention algorithm that enables efficient management of key-value (KV) caches during LLM inference.
This document is based on the original vLLM paper. The current implementation may differ from the description here, but the core concepts remain the same.

Overview

PagedAttention is vLLM’s custom implementation of the multi-head query attention kernel, designed to be compatible with vLLM’s paged KV caches. Key features:
  • Key and value caches are stored in separate blocks
  • Specially designed memory layout for high performance
  • Efficient memory access patterns for GPU threads

Core concepts

Before diving into the implementation, let’s understand some key terminology:

Sequence

A sequence represents a client request. For example, the data pointed to by query q has a shape of [num_seqs, num_heads, head_size]. Because this is a decoding kernel, each sequence contributes exactly one query token, so num_seqs equals the total number of tokens processed in the batch.

Context

The context consists of the tokens of the sequence that have already been processed. For example, ["What", "is", "your"] are context tokens, the input query token is "name", and the model might generate the token "?".

Vec

A vec is a list of elements that are fetched and calculated together:
  • Query/Key vec: VEC_SIZE determined so each thread group can fetch 16 bytes at a time
  • Value vec: V_VEC_SIZE determined so each thread can fetch 16 bytes at a time
Example: If scalar_t is FP16 (2 bytes) and THREAD_GROUP_SIZE is 2:
  • VEC_SIZE = 4
  • V_VEC_SIZE = 8

Thread group

A small group of threads (THREAD_GROUP_SIZE) that fetches and calculates one query token and one key token at a time. Each thread handles only a portion of the token data. The total number of elements processed by one thread group is referred to as x. Example: If the thread group contains 2 threads and head size is 8:
  • Thread 0 handles elements at indices 0, 2, 4, 6
  • Thread 1 handles elements at indices 1, 3, 5, 7

Block

The key and value cache data in vLLM are split into blocks. Each block stores data for a fixed number (BLOCK_SIZE) of tokens at one head. Each block may contain only a portion of the whole context tokens. Example: If block size is 16 and head size is 128, then for one head, one block can store 16 × 128 = 2048 elements.

Warp

A warp is a group of 32 threads (WARP_SIZE) that execute simultaneously on a streaming multiprocessor (SM). In the PagedAttention kernel:
  • Each warp processes the calculation between one query token and key tokens of one entire block at a time
  • May process multiple blocks in multiple iterations
Example: If there are 4 warps and 6 blocks for one context:
  • Warp 0 handles blocks 0, 4
  • Warp 1 handles blocks 1, 5
  • Warp 2 handles block 2
  • Warp 3 handles block 3

Thread block

A thread block is a group of threads (NUM_THREADS) that can access the same shared memory. Each thread block:
  • Contains multiple warps (NUM_WARPS)
  • Processes the calculation between one query token and key tokens of a whole context

Grid

A grid is a collection of thread blocks and defines how those blocks are arranged. In the PagedAttention kernel, the grid shape is (num_heads, num_seqs, max_num_partitions), so each thread block handles the calculation for exactly one head, one sequence, and one partition.

Memory layout

Query memory layout

Each thread defines its own q_ptr which points to the assigned query token data on global memory. Example: If VEC_SIZE is 4 and HEAD_SIZE is 128, the q_ptr points to data that contains 128 elements divided into 128 / 4 = 32 vecs. Shared memory storage:
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
Each vec is assigned to a different row. For example, if THREAD_GROUP_SIZE is 2:
  • Thread 0 handles row 0 vecs
  • Thread 1 handles row 1 vecs
This allows neighboring threads to read neighboring memory, achieving memory coalescing to improve performance.

Key memory layout

Unlike query processing, each thread group may handle multiple key tokens across multiple iterations. The k_ptr in each thread points to different key tokens at different iterations:
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                    + kv_head_idx * kv_head_stride
                    + physical_block_offset * x;
Register memory storage:
K_vec k_vecs[NUM_VECS_PER_THREAD];
We use register memory for k_vecs because it will only be accessed by one thread once, whereas q_vecs will be accessed by multiple threads multiple times.

Value memory layout

Unlike query and key, there is no thread group concept for value data. Key differences:
  • Elements from the same column correspond to the same value token
  • Each thread always fetches V_VEC_SIZE elements from the same V_VEC_SIZE tokens at a time
  • A single thread retrieves multiple v_vecs from different rows and the same columns through multiple inner iterations

Computation flow

1. QK computation

The query-key computation follows this pattern:
q_vecs = ...  // Fetch query data once
for ... {  // Outer loop: iterate over different tokens
    k_ptr = ...
    for ... {  // Inner loop: prepare k_vecs
        k_vecs[i] = ...
    }
    ...
    // Dot product with cross-thread group reduction
    float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
        q_vecs[thread_group_offset], k_vecs
    );
}
Each thread only fetches part of the query and key token data, but cross-thread group reduction in Qk_dot<>::dot produces the full result between entire query and key token data.

2. Softmax computation

We calculate the normalized softmax for all QK scores:

$$
\begin{gather*}
m(x) := \max_i x_i \\
f(x) := \left[\, e^{x_1 - m(x)} \quad \ldots \quad e^{x_B - m(x)} \,\right] \\
\ell(x) := \sum_i f(x)_i \\
\operatorname{softmax}(x) := \frac{f(x)}{\ell(x)}
\end{gather*}
$$

Steps:
Step 1: Collect QK max

After computing each qk result, set the temporary logits result and collect qk_max:
if (thread_group_offset == 0) {
    const bool mask = token_idx >= context_len;
    logits[token_idx - start_token_idx] = mask ? 0.f : qk;
    qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
Step 2: Reduce QK max across warps

Get the reduced qk_max across each warp using warp-level communication:
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
Step 3: Reduce across thread block

Get the final reduced qk_max from the whole thread block and broadcast to all threads:
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
Step 4: Compute exp sum

Compute the exponential sum and convert logits:
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
}
...
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
Step 5: Normalize

With the reduced qk_max and exp_sum, obtain the final normalized softmax result:
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
}

3. LV computation

The logits-value computation follows this pattern:
float accs[NUM_ROWS_PER_THREAD];
for ... {  // Iteration over different blocks
    logits_vec = ...
    for ... {  // Iteration over different rows
        v_vec = ...
        ...
        accs[i] += dot(logits_vec, v_vec);
    }
}
Each entry of accs is mapped to a head position assigned to the current thread. Example: If BLOCK_SIZE is 16 and V_VEC_SIZE is 8:
  • Each thread fetches 8 value elements for 8 tokens at a time
  • Each element is from different tokens at the same head position
  • If HEAD_SIZE is 128 and WARP_SIZE is 32, each inner loop processes 32 × 8 = 256 elements
  • Total of 128 × 16 / 256 = 8 inner iterations per warp for a whole block

4. Reduction and output

Step 1: Warp-level reduction

Perform reduction for accs within each warp:
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
        acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
}
Step 2: Block-level reduction

Perform reduction across all warps, allowing each thread to have the accumulation for assigned head positions of all context tokens.
Step 3: Write output

Write the final results from register memory to global memory:
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                + head_idx * max_num_partitions * HEAD_SIZE
                + partition_idx * HEAD_SIZE;

for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
    if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
    }
}

Performance optimizations

Memory coalescing

Neighboring threads read neighboring memory locations together, which allows the GPU to combine multiple memory accesses into a single transaction.

Memory hierarchy

  • Shared memory: Used for q_vecs that are accessed by multiple threads multiple times
  • Register memory: Used for k_vecs and v_vecs that are accessed by a single thread once
  • Global memory: Used for input/output data

Efficient reduction operations

Warp-level primitives (VLLM_SHFL_XOR_SYNC, VLLM_SHFL_SYNC) are used for efficient reduction operations without requiring shared memory synchronization.

Benefits of PagedAttention

Memory efficiency

Reduces KV cache memory waste to under 4% through paging, compared with the 60-80% waste reported for earlier systems

Dynamic allocation

Allows flexible memory sharing between sequences

High throughput

Increases serving throughput by 2-4x compared to prior state-of-the-art serving systems

Minimal fragmentation

Nearly eliminates internal and external memory fragmentation through block-based allocation

Citation

If you use vLLM’s PagedAttention in your research, please cite:
@inproceedings{kwon2023efficient,
  title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
  author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
  booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
  year={2023}
}

Next steps

Architecture

Learn about vLLM’s overall architecture

Plugin system

Extend vLLM with custom features
