This guide demonstrates how to implement Flash Attention v2 using the CUTLASS CuTe DSL. Flash Attention is a memory-efficient attention algorithm that reduces attention's memory footprint from O(N²) to O(N) by computing the softmax online, block by block.

Overview

Flash Attention v2 computes scaled dot-product attention:
O = softmax(Q @ K^T / sqrt(H)) @ V
Where:
  • Q: Query matrix (B×Sq×N×H)
  • K: Key matrix (B×Sk×N×H)
  • V: Value matrix (B×Sk×N×H)
  • O: Output matrix (B×Sq×N×H)
  • B: batch size, Sq/Sk: query/key sequence lengths, N: number of heads, H: head dimension
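As a concrete reference point (independent of the CuTe DSL), the same computation fits in a few lines of NumPy. The helper name `attention_ref` is chosen here for illustration and is not part of the example code:

```python
import numpy as np

def attention_ref(q, k, v):
    """Reference O = softmax(Q @ K^T / sqrt(H)) @ V for (B, S, N, H) tensors."""
    h = q.shape[-1]
    # Scores: (B, N, Sq, Sk)
    s = np.einsum("bqnh,bknh->bnqk", q, k) / np.sqrt(h)
    # Numerically stable softmax over the key axis
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    # Output: (B, Sq, N, H)
    return np.einsum("bnqk,bknh->bqnh", p, v)

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 4, 2, 8))   # B=1, Sq=4, N=2, H=8
k = rng.standard_normal((1, 6, 2, 8))   # Sk=6
v = rng.standard_normal((1, 6, 2, 8))
print(attention_ref(q, k, v).shape)  # (1, 4, 2, 8)
```

This materializes the full Sq×Sk score matrix; the rest of the guide shows how Flash Attention avoids exactly that.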

Key Features

  • Online softmax: Computes the softmax incrementally without materializing the full attention matrix
  • Tiled computation: Processes attention in blocks to fit in shared memory
  • Register pipeline: Overlaps shared memory loads with computation
  • Causal masking: Optional support for autoregressive models

Architecture

┌─────────┐  ┌─────────┐
│    Q    │  │  K, V   │
└────┬────┘  └────┬────┘
     │            │
     ├─CpAsync────┤
     ↓            ↓
  ┌────────────────────┐
  │  Shared Memory     │
  │  sQ    sK    sV    │
  └──┬──────┬──────┬───┘
     │      │      │
     ↓      ↓      ↓
  ┌────────────────────┐
  │  Tensor Core MMA   │
  │    S = Q×K^T       │
  └─────────┬──────────┘

  ┌──────────────────────┐
  │  Online Softmax      │
  │  - Update row_max    │
  │  - Update row_sum    │
  │  - Rescale prev O    │
  └─────────┬────────────┘

  ┌──────────────────────┐
  │  Tensor Core MMA     │
  │    O = P×V           │
  └─────────┬────────────┘

  ┌──────────────────────┐
  │  Normalize & Store   │
  └──────────────────────┘

Implementation

Define the Attention Kernel

Create the Flash Attention class:

import cutlass
import cutlass.cute as cute
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, warp

class FlashAttentionForwardAmpere:
    def __init__(
        self,
        head_dim: int,
        m_block_size: int = 128,
        n_block_size: int = 128,
        num_threads: int = 128,
        is_causal: bool = False,
    ):
        self._head_dim = head_dim
        self._m_block_size = m_block_size
        self._n_block_size = n_block_size
        self._head_dim_padded = (head_dim + 31) // 32 * 32
        self._num_threads = num_threads
        self._is_causal = is_causal
        
        self.cta_sync_barrier = pipeline.NamedBarrier(
            barrier_id=1, num_threads=num_threads
        )
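The padding arithmetic above rounds the head dimension up to a multiple of 32 so that shared-memory tiles stay aligned. A quick standalone check (plain Python, not DSL code; the helper name is ours):

```python
# Mirrors self._head_dim_padded = (head_dim + 31) // 32 * 32 above
def pad_to_32(head_dim: int) -> int:
    return (head_dim + 31) // 32 * 32

print(pad_to_32(64), pad_to_32(72), pad_to_32(96))  # 64 96 96
```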
Configure Shared Memory Layouts

Set up swizzled layouts for Q, K, V:

# Determine swizzle parameters
smem_k_block_size = 64 if self._head_dim_padded % 64 == 0 else 32
swizzle_bits = 3 if smem_k_block_size == 64 else 2

# Create swizzled layout atom
sQ_layout_atom = cute.make_composed_layout(
    cute.make_swizzle(swizzle_bits, 3, 3),
    0,
    cute.make_layout((8, smem_k_block_size), stride=(smem_k_block_size, 1)),
)

# Tile to full shape
sQ_layout = cute.tile_to_shape(
    sQ_layout_atom,
    (self._m_block_size, self._head_dim_padded),
    (0, 1),
)

# K and V use same layout
sKV_layout = cute.tile_to_shape(
    sQ_layout_atom,
    (self._n_block_size, self._head_dim_padded),
    (0, 1),
)
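To see what the swizzle buys, here is a plain-Python model of a CuTe-style `Swizzle<B,M,S>`, assuming the usual convention that for S >= 0 the bits [M+S, M+S+B) of an offset are XORed into bits [M, M+B). This is an illustration of the bit manipulation, not the DSL API:

```python
def swizzle(offset: int, b: int = 3, m: int = 3, s: int = 3) -> int:
    """Model of Swizzle<B,M,S>: XOR bits [M+S, M+S+B) of offset into bits [M, M+B)."""
    mask = ((1 << b) - 1) << (m + s)
    return offset ^ ((offset & mask) >> s)

# The swizzle is an involutive permutation: every element keeps a unique slot.
n = 8 * 64  # one (8, 64)-element layout atom
assert sorted(swizzle(o) for o in range(n)) == list(range(n))
assert all(swizzle(swizzle(o)) == o for o in range(n))

# Column 0 of a row-major (8, 64) tile: unswizzled, all 8 rows land on the
# same column (and hence the same bank group); swizzled, they spread out.
print(len({swizzle(row * 64) % 64 for row in range(8)}))  # 8 distinct columns
```

Spreading same-column accesses across distinct swizzled columns is what avoids shared-memory bank conflicts when a warp reads a column of the tile.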
Create Tiled MMA and Copy Operations

Define MMA and copy atoms:

# Create MMA instruction
tiled_mma = cute.make_tiled_mma(
    warp.MmaF16BF16Op(self._dtype, cutlass.Float32, (16, 8, 16)),
    (self._num_threads // 32, 1, 1),
    permutation_mnk=(self._num_threads // 32 * 16, 16, 16),
)

# Create async copy atom for Q/K/V
atom_async_copy = cute.make_copy_atom(
    cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
    self._dtype,
    num_bits_per_copy=128,
)

# Create tiled copy; tQKV_layout / vQKV_layout are the thread and value
# layouts for the copy partitioning (definitions elided here)
gmem_tiled_copy_QKV = cute.make_tiled_copy_tv(
    atom_async_copy, tQKV_layout, vQKV_layout
)
Implement Online Softmax

The key innovation: incremental softmax computation:

@cute.jit
def softmax_rescale_O(
    self,
    acc_S: cute.Tensor,
    acc_O: cute.Tensor,
    row_max: cute.Tensor,
    row_sum: cute.Tensor,
    softmax_scale_log2: cutlass.Float32,
    is_first_n_block: bool,
):
    # Convert to M×N view
    acc_S_mn = self._make_acc_tensor_mn_view(acc_S)
    acc_O_mn = self._make_acc_tensor_mn_view(acc_O)
    
    # Process each row
    for r in cutlass.range_constexpr(cute.size(row_max)):
        # Load current row of S
        acc_S_row = acc_S_mn[r, None].load()
        
        # Compute new max
        row_max_cur = acc_S_row.reduce(
            cute.ReductionOp.MAX, -cutlass.Float32.inf, 0
        )
        row_max_cur = self._threadquad_reduce_max(row_max_cur)
        
        # Update max if not first block
        if not is_first_n_block:
            row_max_prev = row_max[r]
            row_max_cur = cute.arch.fmax(row_max_prev, row_max_cur)
        
        # Compute exp(S - max)
        acc_S_row_exp = cute.math.exp2(
            acc_S_row * softmax_scale_log2 - row_max_cur * softmax_scale_log2,
            fastmath=True,
        )
        
        # Update sum
        acc_S_row_sum = acc_S_row_exp.reduce(
            cute.ReductionOp.ADD, cutlass.Float32.zero, 0
        )
        
        if not is_first_n_block:
            # Correction factor for previous blocks
            prev_minus_cur_exp = cute.math.exp2(
                row_max_prev * softmax_scale_log2 - row_max_cur * softmax_scale_log2,
                fastmath=True,
            )
            acc_S_row_sum = acc_S_row_sum + row_sum[r] * prev_minus_cur_exp
            # Rescale previous O
            acc_O_mn[r, None] = acc_O_mn[r, None].load() * prev_minus_cur_exp
        
        # Update running statistics
        row_max[r] = row_max_cur
        row_sum[r] = acc_S_row_sum
        acc_S_mn[r, None] = acc_S_row_exp
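The same per-row update can be sketched in NumPy, which also shows why the kernel works with `exp2` and a pre-scaled exponent: `exp2(x * scale * log2(e))` equals `exp(scale * x)` but maps to the cheaper hardware `ex2` instruction. The helper name and signature below are ours, not the DSL's:

```python
import numpy as np

LOG2_E = 1.4426950408889634  # log2(e)

def softmax_rescale_step(s_row, acc_o_row, row_max, row_sum, scale, first):
    """One online-softmax step for a single row of scores."""
    scale_log2 = scale * LOG2_E
    cur_max = s_row.max() if first else max(row_max, s_row.max())
    p_row = np.exp2(s_row * scale_log2 - cur_max * scale_log2)  # exp(scale*(s - max))
    cur_sum = p_row.sum()
    if not first:
        # Rescale running sum and output by exp(scale*(old_max - new_max))
        correction = np.exp2(row_max * scale_log2 - cur_max * scale_log2)
        cur_sum += row_sum * correction
        acc_o_row = acc_o_row * correction
    return p_row, acc_o_row, cur_max, cur_sum

# Two 4-wide blocks reproduce the full 8-wide softmax denominator exactly.
s = np.random.default_rng(1).standard_normal(8)
_, o, m, z = softmax_rescale_step(s[:4], np.zeros(1), -np.inf, 0.0, 0.5, True)
_, o, m, z = softmax_rescale_step(s[4:], o, m, z, 0.5, False)
assert np.isclose(z, np.exp(0.5 * s - (0.5 * s).max()).sum())
```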
Main Computation Loop

Process K/V in blocks:

@cute.kernel
def kernel(self, mQ, mK, mV, mO, softmax_scale_log2, ...):
    tidx, _, _ = cute.arch.thread_idx()
    m_block, batch_size, num_head = cute.arch.block_idx()
    
    # Allocate shared memory
    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(SharedStorage)
    sQ = storage.sQ.get_tensor(sQ_layout)
    sK = storage.sK.get_tensor(sKV_layout)
    sV = storage.sV.get_tensor(sKV_layout)
    
    # Initialize statistics
    row_max = cute.make_rmem_tensor(
        (acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32
    )
    row_sum = cute.make_rmem_tensor(
        (acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32
    )
    row_max.fill(-cutlass.Float32.inf)
    row_sum.fill(0.0)
    
    # Load Q tile (stays constant)
    cute.copy(gmem_tiled_copy_QKV, tQgQ, tQsQ)
    cute.arch.cp_async_commit_group()
    
    # Process K/V tiles
    n_block_max = cute.ceil_div(mK.shape[1], self._n_block_size)
    if self._is_causal:
        n_block_max = min(
            cute.ceil_div((m_block + 1) * self._m_block_size, self._n_block_size),
            n_block_max,
        )
    
    for n_block in range(n_block_max):
        # Load K, V tiles
        cute.copy(gmem_tiled_copy_QKV, tKgK[n_block], tKsK)
        cute.copy(gmem_tiled_copy_QKV, tVgV[n_block], tVsV)
        cute.arch.cp_async_wait_group(0)
        self.cta_sync_barrier.arrive_and_wait()
        
        # Compute S = Q @ K^T
        acc_S = cute.make_rmem_tensor(acc_shape_S, cutlass.Float32)
        acc_S.fill(0.0)
        cute.gemm(tiled_mma, acc_S, tSrQ, tSrK, acc_S)
        
        # Apply causal mask if needed
        if self._is_causal and n_block == n_block_max - 1:
            self.apply_causal_mask(acc_S, m_block, n_block, mK.shape[1])
        
        # Online softmax and rescale O
        self.softmax_rescale_O(
            acc_S, acc_O, row_max, row_sum, softmax_scale_log2, n_block == 0
        )
        
        # Convert to P and compute O += P @ V
        rP = cute.make_fragment_like(acc_S, self._dtype)
        rP.store(acc_S.load().to(self._dtype))
        cute.gemm(tiled_mma, acc_O, rP, tOrVt, acc_O)
    
    # Final normalization
    self.normalize_softmax(acc_O, row_sum)
    
    # Store output
    cute.copy(gmem_tiled_copy_O, tOrO, tOgO)
Apply Causal Masking

Mask future positions for autoregressive models:

def apply_causal_mask(self, acc_S, m_block, n_block, seqlen_k):
    acc_S_mn = self._make_acc_tensor_mn_view(acc_S)
    
    for r in cutlass.range_constexpr(cute.size(acc_S_mn.shape[0])):
        # Last visible column for this query row
        col_idx_limit = cutlass.min(
            (m_block * self._m_block_size + r + 1),
            seqlen_k,
        )
        
        # Mask positions beyond limit
        for c in cutlass.range_constexpr(cute.size(acc_S_mn.shape[1])):
            col_idx = n_block * self._n_block_size + c
            if col_idx >= col_idx_limit:
                acc_S_mn[r, c] = -cutlass.Float32.inf
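The masking logic can be checked outside the DSL with a NumPy analogue (the function name is ours): row r of the tile at blocks (m_block, n_block) may only see key columns strictly below its global query index plus one.

```python
import numpy as np

def causal_mask_tile(s_tile, m_block, n_block, m_bs, n_bs, seqlen_k):
    """Return a copy of one (m_bs, n_bs) score tile with future positions at -inf."""
    out = s_tile.copy()
    for r in range(m_bs):
        # Columns at or beyond min(global_query_index + 1, seqlen_k) are masked
        col_idx_limit = min(m_block * m_bs + r + 1, seqlen_k)
        for c in range(n_bs):
            if n_block * n_bs + c >= col_idx_limit:
                out[r, c] = -np.inf
    return out

# The diagonal tile of a sequence comes out lower-triangular:
tile = causal_mask_tile(np.zeros((4, 4)), 0, 0, 4, 4, seqlen_k=16)
print(np.isfinite(tile).astype(int))  # lower-triangular pattern of 1s
```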

Running Examples

python examples/python/CuTeDSL/ampere/flash_attention_v2.py \
  --dtype Float16 \
  --head_dim 128 \
  --m_block_size 128 \
  --n_block_size 128 \
  --num_threads 128 \
  --batch_size 1 \
  --seqlen_q 1280 \
  --seqlen_k 1536 \
  --num_head 16 \
  --softmax_scale 1.0

Performance Profiling

ncu python examples/python/CuTeDSL/ampere/flash_attention_v2.py \
  --dtype Float16 --head_dim 128 \
  --m_block_size 128 --n_block_size 128 \
  --num_threads 128 --batch_size 1 \
  --seqlen_q 1280 --seqlen_k 1536 \
  --num_head 16 --softmax_scale 1.0 \
  --is_causal --skip_ref_check

Key Concepts

The online softmax algorithm processes attention in blocks without materializing the full attention matrix.

Traditional approach (memory intensive):
S = Q @ K^T              # O(N²) memory
P = softmax(S)           # O(N²) memory
O = P @ V                # O(N²) memory
Flash Attention approach (memory efficient):
For each K/V block:
  1. Compute S_block = Q @ K_block^T
  2. Update row_max = max(row_max_prev, max(S_block))
  3. Rescale previous O by exp(row_max_prev - row_max)
  4. Update row_sum with new exp values
  5. Compute O += softmax(S_block) @ V_block

Finally: O = O / row_sum
This reduces memory from O(N²) to O(N).
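The five steps above, plus the final normalization, can be run end to end in NumPy for a single head and checked against the full softmax (the helper name is ours):

```python
import numpy as np

def flash_attention_1head(q, k, v, n_block_size):
    """Blockwise attention for one head; never materializes the full S or P matrix."""
    sq, h = q.shape
    scale = 1.0 / np.sqrt(h)
    o = np.zeros((sq, h))
    row_max = np.full(sq, -np.inf)
    row_sum = np.zeros(sq)
    for start in range(0, k.shape[0], n_block_size):
        kb, vb = k[start:start + n_block_size], v[start:start + n_block_size]
        s = (q @ kb.T) * scale                    # 1. S_block = Q @ K_block^T
        new_max = np.maximum(row_max, s.max(-1))  # 2. update running row max
        corr = np.exp(row_max - new_max)          # 3. rescale factor for old O / sum
        p = np.exp(s - new_max[:, None])
        row_sum = row_sum * corr + p.sum(-1)      # 4. update running row sum
        o = o * corr[:, None] + p @ vb            # 5. O += exp(S_block - max) @ V_block
        row_max = new_max
    return o / row_sum[:, None]                   # finally: O = O / row_sum

rng = np.random.default_rng(2)
q, k, v = (rng.standard_normal((n, 16)) for n in (8, 32, 32))
s = (q @ k.T) / np.sqrt(16)
p = np.exp(s - s.max(-1, keepdims=True))
ref = (p / p.sum(-1, keepdims=True)) @ v
assert np.allclose(flash_attention_1head(q, k, v, n_block_size=8), ref)
```

Only per-row statistics (`row_max`, `row_sum`) and one score block are ever live, which is the O(N) memory claim in concrete form.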
The register pipeline overlaps shared memory loads with computation:
# Load first k-block
cute.copy(smem_tiled_copy_Q, tSsQ[0], tSrQ_view[0])
cute.copy(smem_tiled_copy_K, tSsK[0], tSrK_view[0])

for k in range(num_k_blocks):
    # Load next block while computing current
    k_next = (k + 1) % num_k_blocks
    cute.copy(smem_tiled_copy_Q, tSsQ[k_next], tSrQ_view[k_next])
    cute.copy(smem_tiled_copy_K, tSsK[k_next], tSrK_view[k_next])
    
    # Compute with current block
    cute.gemm(tiled_mma, acc_S, tSrQ[k], tSrK[k], acc_S)
Choosing optimal tile sizes:

m_block_size (Query tile):
  • Larger: Better arithmetic intensity, more registers
  • Smaller: Lower SMEM usage, better for long sequences
  • Typical: 64, 128
n_block_size (Key/Value tile):
  • Affects online softmax accuracy
  • Should match m_block_size for balanced computation
  • Typical: 64, 128
head_dim:
  • Usually fixed by model (64, 128, 256)
  • Must be 16-byte aligned
  • Pad if necessary
Constraint: (m_block_size * head_dim + 2 * n_block_size * head_dim) * sizeof(dtype) must fit in SMEM
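Plugging the defaults (m_block_size = n_block_size = 128, head_dim = 128, fp16) into this constraint gives 96 KiB, which requires the opt-in shared-memory carve-out on Ampere (up to 163 KiB per block on A100) rather than the default 48 KiB. A quick check (helper name ours):

```python
def smem_bytes(m_block, n_block, head_dim_padded, dtype_bytes=2):
    """SMEM for one Q tile plus one K and one V tile, per the constraint above."""
    return (m_block * head_dim_padded + 2 * n_block * head_dim_padded) * dtype_bytes

usage = smem_bytes(128, 128, 128)  # default configuration, fp16
print(usage, usage / 1024)  # 98304 bytes = 96.0 KiB
```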

Memory Analysis

Component     Standard Attention   Flash Attention v2
S matrix      O(B·N·Sq·Sk)         O(B·N·m_block·n_block)
P matrix      O(B·N·Sq·Sk)         O(B·N·m_block·n_block)
Total GMEM    O(N²)                O(N)
SMEM usage    Low                  High (Q+K+V tiles)

Constraints

  • Only fp16 and bf16 data types are supported
  • Head dimension must be 16-byte aligned (a multiple of 8 fp16/bf16 elements)
  • m_block_size * 2 must be divisible by num_threads
  • Total SMEM: (m_block + 2*n_block) * head_dim * sizeof(dtype) must fit in capacity
  • The log-sum-exp values needed for the training backward pass are not computed

Source Code

Flash Attention v2 (Ampere)

Complete source: examples/python/CuTeDSL/ampere/flash_attention_v2.py

Flash Attention (Hopper)

Hopper variant: examples/python/CuTeDSL/hopper/fmha.py
