Overview
Flash Attention v2 computes scaled dot-product attention:
- Q: Query matrix (B×Sq×N×H)
- K: Key matrix (B×Sk×N×H)
- V: Value matrix (B×Sk×N×H)
- O: Output matrix (B×Sq×N×H)
- B: batch size, Sq/Sk: sequence length, N: number of heads, H: head dimension
Key Features
- Online softmax: Computes softmax incrementally without materializing full attention matrix
- Tiled computation: Processes attention in blocks to fit in shared memory
- Register pipeline: Overlaps shared memory loads with computation
- Causal masking: Optional support for autoregressive models
Architecture
Implementation
import cutlass
import cutlass.cute as cute
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, warp
class FlashAttentionForwardAmpere:
    """Flash Attention v2 forward pass for NVIDIA Ampere GPUs, expressed
    in the CuTe DSL.

    Configures tile sizes, swizzled shared-memory layouts, the tiled MMA,
    and the async global-to-shared copy used to stage Q/K/V tiles.

    NOTE(review): this is a documentation excerpt — names used below such
    as ``self._dtype``, ``tQKV_layout`` and ``vQKV_layout`` are defined in
    the full example and are not visible here. The computed layouts and
    copies (``sQ_layout``, ``sKV_layout``, ``tiled_mma``,
    ``gmem_tiled_copy_QKV``) are also not stored on ``self`` in this
    excerpt; confirm against the complete source.
    """

    def __init__(
        self,
        head_dim: int,
        m_block_size: int = 128,
        n_block_size: int = 128,
        num_threads: int = 128,
        is_causal: bool = False,
    ):
        # Tile configuration: m_block covers query rows, n_block covers
        # key/value rows processed per iteration.
        self._head_dim = head_dim
        self._m_block_size = m_block_size
        self._n_block_size = n_block_size
        # Round head_dim up to the next multiple of 32 for the padded
        # shared-memory tile width.
        self._head_dim_padded = (head_dim + 31) // 32 * 32
        self._num_threads = num_threads
        self._is_causal = is_causal
        # Named barrier covering every thread in the CTA; used to order
        # shared-memory producers and consumers inside the kernel.
        self.cta_sync_barrier = pipeline.NamedBarrier(
            barrier_id=1, num_threads=num_threads
        )
        # Determine swizzle parameters: a wider (64-element) inner block
        # gets a stronger swizzle to avoid shared-memory bank conflicts.
        smem_k_block_size = 64 if self._head_dim_padded % 64 == 0 else 32
        swizzle_bits = 3 if smem_k_block_size == 64 else 2
        # Create the swizzled layout atom: an (8, smem_k_block_size)
        # row-major tile composed with the swizzle function.
        sQ_layout_atom = cute.make_composed_layout(
            cute.make_swizzle(swizzle_bits, 3, 3),
            0,
            cute.make_layout((8, smem_k_block_size), stride=(smem_k_block_size, 1)),
        )
        # Tile the atom up to the full Q tile shape.
        sQ_layout = cute.tile_to_shape(
            sQ_layout_atom,
            (self._m_block_size, self._head_dim_padded),
            (0, 1),
        )
        # K and V use the same layout atom, tiled to the K/V tile shape.
        sKV_layout = cute.tile_to_shape(
            sQ_layout_atom,
            (self._n_block_size, self._head_dim_padded),
            (0, 1),
        )
        # Tiled MMA built from the 16x8x16 f16/bf16 tensor-core op, with
        # one warp per 32 threads along M.
        tiled_mma = cute.make_tiled_mma(
            warp.MmaF16BF16Op(self._dtype, cutlass.Float32, (16, 8, 16)),
            (self._num_threads // 32, 1, 1),
            permutation_mnk=(self._num_threads // 32 * 16, 16, 16),
        )
        # Async (cp.async) global-to-shared copy atom for Q/K/V, moving
        # 128 bits per copy instruction.
        atom_async_copy = cute.make_copy_atom(
            cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
            self._dtype,
            num_bits_per_copy=128,
        )
        # Tiled copy combining the atom with thread/value layouts.
        # NOTE(review): tQKV_layout / vQKV_layout come from the full source.
        gmem_tiled_copy_QKV = cute.make_tiled_copy_tv(
            atom_async_copy, tQKV_layout, vQKV_layout
        )
@cute.jit
def softmax_rescale_O(
    self,
    acc_S: cute.Tensor,
    acc_O: cute.Tensor,
    row_max: cute.Tensor,
    row_sum: cute.Tensor,
    is_first_n_block: bool,
):
    """Apply one online-softmax step to the score accumulator ``acc_S``
    and rescale the partial output ``acc_O`` accordingly.

    Maintains per-row running statistics across K/V blocks:
      row_max -- running maximum of the scores per row
      row_sum -- running sum of exp(score - max) per row

    On exit, ``acc_S`` holds the unnormalized probabilities
    exp(scale * S - max) for this block, and ``acc_O`` has been rescaled
    by exp(old_max - new_max) so later P @ V accumulation stays correct.

    NOTE(review): ``softmax_scale_log2`` is not a parameter here; in the
    full example it is captured from the enclosing scope. Presumably it is
    the softmax scale pre-multiplied by log2(e) so exp() can be evaluated
    via the faster exp2 — confirm against the complete source.
    """
    # Convert the MMA accumulator fragments to an M×N view.
    acc_S_mn = self._make_acc_tensor_mn_view(acc_S)
    acc_O_mn = self._make_acc_tensor_mn_view(acc_O)
    # Process each accumulator row held by this thread.
    for r in cutlass.range_constexpr(cute.size(row_max)):
        # Load current row of S.
        acc_S_row = acc_S_mn[r, None].load()
        # Per-thread max over this row's elements.
        row_max_cur = acc_S_row.reduce(
            cute.ReductionOp.MAX, -cutlass.Float32.inf, 0
        )
        # Finish the max reduction across the thread quad that shares
        # this accumulator row.
        row_max_cur = self._threadquad_reduce_max(row_max_cur)
        # Merge with the running max from previous blocks.
        if not is_first_n_block:
            row_max_prev = row_max[r]
            row_max_cur = cute.arch.fmax(row_max_prev, row_max_cur)
        # exp(scale * (S - max)), expressed in base 2.
        acc_S_row_exp = cute.math.exp2(
            acc_S_row * softmax_scale_log2 - row_max_cur * softmax_scale_log2,
            fastmath=True,
        )
        # Partial sum of the exponentials for this block.
        acc_S_row_sum = acc_S_row_exp.reduce(
            cute.ReductionOp.ADD, cutlass.Float32.zero, 0
        )
        if not is_first_n_block:
            # Correction factor exp(old_max - new_max) applied to the
            # statistics and output accumulated from previous blocks.
            prev_minus_cur_exp = cute.math.exp2(
                row_max_prev * softmax_scale_log2 - row_max_cur * softmax_scale_log2,
                fastmath=True,
            )
            acc_S_row_sum = acc_S_row_sum + row_sum[r] * prev_minus_cur_exp
            # Rescale the previously accumulated O row.
            acc_O_mn[r, None] = acc_O_mn[r, None].load() * prev_minus_cur_exp
        # Commit the updated running statistics.
        row_max[r] = row_max_cur
        row_sum[r] = acc_S_row_sum
        # Store exp(S - max) back in place; it becomes P for the P @ V GEMM.
        acc_S_mn[r, None] = acc_S_row_exp
@cute.kernel
def kernel(self, mQ, mK, mV, mO, softmax_scale_log2, ...):
    """Fused forward attention kernel: O = softmax(scale * Q @ K^T) @ V.

    Each thread block handles one (m_block, batch, head) query tile,
    iterating over K/V tiles along the key sequence, combining results
    with the online-softmax update in ``softmax_rescale_O``.

    NOTE(review): this is an abridged excerpt — the literal ``...`` in the
    signature and names such as ``acc_O``, ``SharedStorage``,
    ``sQ_layout``/``sKV_layout``, ``gmem_tiled_copy_QKV``, ``acc_shape_S``,
    the partitioned tensors (``tQgQ``/``tQsQ``, ``tKgK``/``tKsK``,
    ``tVgV``/``tVsV``, ``tSrQ``/``tSrK``, ``tOrVt``, ``tOrO``/``tOgO``)
    and ``gmem_tiled_copy_O`` are defined in the full example, not here.
    """
    # Grid mapping: one block per (query tile, batch, head).
    tidx, _, _ = cute.arch.thread_idx()
    m_block, batch_size, num_head = cute.arch.block_idx()
    # Allocate shared memory for the Q, K and V tiles.
    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(SharedStorage)
    sQ = storage.sQ.get_tensor(sQ_layout)
    sK = storage.sK.get_tensor(sKV_layout)
    sV = storage.sV.get_tensor(sKV_layout)
    # Per-thread running softmax statistics, one slot per accumulator row.
    row_max = cute.make_rmem_tensor(
        (acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32
    )
    row_sum = cute.make_rmem_tensor(
        (acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32
    )
    row_max.fill(-cutlass.Float32.inf)
    row_sum.fill(0.0)
    # Load the Q tile once; it stays resident for all K/V iterations.
    cute.copy(gmem_tiled_copy_QKV, tQgQ, tQsQ)
    cute.arch.cp_async_commit_group()
    # Number of K/V tiles to process.
    # NOTE(review): mK.shape[1] is used as the key sequence length here —
    # confirm the layout convention against the full source.
    n_block_max = cute.ceil_div(mK.shape[1], self._n_block_size)
    if self._is_causal:
        # Causal: rows of this query tile attend to no column beyond
        # (m_block + 1) * m_block_size, so later K/V tiles are skipped.
        n_block_max = min(
            cute.ceil_div((m_block + 1) * self._m_block_size, self._n_block_size),
            n_block_max,
        )
    for n_block in range(n_block_max):
        # Stage the K and V tiles into shared memory (async copies).
        cute.copy(gmem_tiled_copy_QKV, tKgK[n_block], tKsK)
        cute.copy(gmem_tiled_copy_QKV, tVgV[n_block], tVsV)
        # Drain all outstanding cp.async groups, then sync the CTA so
        # every thread sees the freshly staged tiles.
        cute.arch.cp_async_wait_group(0)
        self.cta_sync_barrier.arrive_and_wait()
        # Compute S = Q @ K^T into a zeroed register accumulator.
        acc_S = cute.make_rmem_tensor(acc_shape_S, cutlass.Float32)
        acc_S.fill(0.0)
        cute.gemm(tiled_mma, acc_S, tSrQ, tSrK, acc_S)
        # Apply the causal mask where needed.
        # NOTE(review): only the final n_block is masked in this excerpt;
        # with these tile sizes only the diagonal block straddles the
        # mask boundary — verify against the full source.
        if self._is_causal and n_block == n_block_max - 1:
            self.apply_causal_mask(acc_S, m_block, n_block)
        # Online softmax update and rescale of the partial output.
        self.softmax_rescale_O(
            acc_S, acc_O, row_max, row_sum, n_block == 0
        )
        # Convert probabilities to the MMA input dtype and accumulate
        # O += P @ V.
        rP = cute.make_fragment_like(acc_S, self._dtype)
        rP.store(acc_S.load().to(self._dtype))
        cute.gemm(tiled_mma, acc_O, rP, tOrVt, acc_O)
    # Final normalization: divide accumulated O rows by their row sums.
    self.normalize_softmax(acc_O, row_sum)
    # Write the output tile back to global memory.
    cute.copy(gmem_tiled_copy_O, tOrO, tOgO)
def apply_causal_mask(self, acc_S, m_block, n_block):
    """Mask out future positions for causal attention by writing -inf
    into every score whose global column index is at or beyond its row's
    allowed limit.

    NOTE(review): ``mK`` is not in scope in this excerpt — presumably it
    is the key tensor captured from the enclosing kernel; verify against
    the full source. Likewise ``r`` here indexes the thread-local
    accumulator rows; the full example maps it to a global query row
    before forming the limit.
    """
    # M×N view of the score accumulator fragments.
    acc_S_mn = self._make_acc_tensor_mn_view(acc_S)
    for r in cutlass.range_constexpr(cute.size(acc_S_mn.shape[0])):
        # A query at global row i may attend to columns <= i; clamp to
        # the key sequence length.
        col_idx_limit = cutlass.min(
            (m_block * self._m_block_size + r + 1),
            mK.shape[1]
        )
        # Set every score at or past the limit to -inf so it vanishes
        # after softmax.
        for c in cutlass.range_constexpr(cute.size(acc_S_mn.shape[1])):
            col_idx = n_block * self._n_block_size + c
            if col_idx >= col_idx_limit:
                acc_S_mn[r, c] = -cutlass.Float32.inf
Running Examples
Performance Profiling
Key Concepts
Online Softmax Algorithm
The online softmax algorithm processes attention in blocks without materializing the full attention matrix. The traditional approach (memory intensive) computes and stores the entire Sq×Sk score matrix before applying softmax. The Flash Attention approach (memory efficient) instead keeps only a running row maximum and row sum, rescaling the partial output as each block is processed. This reduces intermediate memory from O(N²) to O(N).
Register Pipeline
The register pipeline overlaps shared memory loads with computation:
Tile Size Selection
Choosing optimal tile sizes:
m_block_size (Query tile):
- Larger: Better arithmetic intensity, more registers
- Smaller: Lower SMEM usage, better for long sequences
- Typical: 64, 128
n_block_size (Key/Value tile):
- Affects online softmax accuracy
- Should match m_block_size for balanced computation
- Typical: 64, 128
head_dim:
- Usually fixed by model (64, 128, 256)
- Must be 16-byte aligned
- Pad if necessary
SMEM budget: (m_block_size * head_dim + 2 * n_block_size * head_dim) * sizeof(dtype) must fit in SMEM.
Memory Analysis
| Component | Standard Attention | Flash Attention v2 |
|---|---|---|
| S matrix | O(B·N·Sq·Sk) | O(B·N·m_block·n_block) |
| P matrix | O(B·N·Sq·Sk) | O(B·N·m_block·n_block) |
| Total GMEM | O(N²) | O(N) |
| SMEM usage | Low | High (Q+K+V tiles) |
Constraints
Source Code
Flash Attention v2 (Ampere)
Complete source:
examples/python/CuTeDSL/ampere/flash_attention_v2.py
Flash Attention (Hopper)
Hopper variant:
examples/python/CuTeDSL/hopper/fmha.py