
Overview

Nanochat uses a unified Flash Attention interface that automatically selects the best implementation based on available hardware:
  • Flash Attention 3 (FA3): On Hopper GPUs (H100, H200) - fastest
  • PyTorch SDPA: Fallback for Ada, Blackwell, Ampere, MPS, and CPU
This provides optimal performance across different hardware without code changes.

Key Features

  • Drop-in replacement for FA3 with identical API
  • Zero-overhead hardware detection (once at import time)
  • Supports causal attention and sliding window attention
  • KV cache support for inference
  • Automatic layout conversion for SDPA fallback
Reference: flash_attention.py:1-15

Usage

from nanochat.flash_attention import flash_attn

# Training (no KV cache)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)

# Inference (with KV cache)
y = flash_attn.flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k, v=v,
    cache_seqlens=cache_seqlens,
    causal=True,
    window_size=window_size,
)
The API is identical whether FA3 or SDPA is used under the hood.

Hardware Detection

Loading FA3

FA3 is loaded at import time if conditions are met:
def _load_flash_attention_3():
    if not torch.cuda.is_available():
        return None
    
    major, _ = torch.cuda.get_device_capability()
    
    # FA3 only on Hopper (sm90)
    if major != 9:
        return None
    
    from kernels import get_kernel
    return get_kernel('varunneal/flash-attention-3').flash_attn_interface
Supported Hardware:
  • ✅ Hopper (H100, H200) - compute capability 9.0
  • ❌ Blackwell - compute capability 10.0 (needs recompilation)
  • ❌ Ada (RTX 4090) - compute capability 8.9
  • ❌ Ampere (A100) - compute capability 8.0
Reference: flash_attention.py:23-38

Why Not Blackwell?

FA3 kernels are compiled specifically for sm90 (Hopper). Blackwell (sm100) requires kernel recompilation. The SDPA fallback provides good performance until FA3 adds Blackwell support.

Detection Result

from nanochat.flash_attention import HAS_FA3

if HAS_FA3:
    print("Using Flash Attention 3")
else:
    print("Using PyTorch SDPA fallback")
Reference: flash_attention.py:42

API Reference

flash_attn_func

flash_attn.flash_attn_func(
    q,              # (B, T, H, D) - queries
    k,              # (B, T, H_kv, D) - keys
    v,              # (B, T, H_kv, D) - values
    causal=False,   # Use causal masking
    window_size=(-1, -1),  # (left, right) window
)
Tensor Layout: (batch, sequence, heads, dim) - FA3’s native layout
Window Size:
  • (-1, 0): Full causal attention
  • (N, 0): Sliding window, attend to last N+1 tokens (the current token plus N previous)
  • (-1, -1): Full bidirectional (not causal)
Returns: Output tensor of shape (B, T, H, D)
Reference: flash_attention.py:99-120

flash_attn_with_kvcache

flash_attn.flash_attn_with_kvcache(
    q,                  # (B, T_new, H, D) - queries
    k_cache,            # (B, T_max, H_kv, D) - key cache
    v_cache,            # (B, T_max, H_kv, D) - value cache
    k=None,             # (B, T_new, H_kv, D) - new keys
    v=None,             # (B, T_new, H_kv, D) - new values
    cache_seqlens=None, # (B,) - current position in cache
    causal=False,
    window_size=(-1, -1),
)
In-place Updates: Both FA3 and SDPA versions update k_cache and v_cache in-place.
Cache Management:
  • k_cache, v_cache: Pre-allocated tensors (usually T_max = 4096 or larger)
  • cache_seqlens: Tensor tracking current position (shape: (B,))
  • New keys/values inserted at position cache_seqlens[0]:cache_seqlens[0]+T_new
Returns: Output tensor of shape (B, T_new, H, D)
Reference: flash_attention.py:123-169
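The in-place update semantics can be sketched in plain PyTorch (a reference illustration of the cache-write step, not the actual kernel):

```python
import torch

def kvcache_update_reference(k_cache, v_cache, k, v, cache_seqlens):
    """Reference semantics of the in-place cache update: new keys/values
    are written at cache_seqlens[b] : cache_seqlens[b] + T_new for each
    batch row. Cache layout is (B, T_max, H_kv, D)."""
    T_new = k.size(1)
    for b in range(k.size(0)):
        pos = int(cache_seqlens[b])
        k_cache[b, pos:pos + T_new] = k[b]
        v_cache[b, pos:pos + T_new] = v[b]

B, T_max, T_new, H_kv, D = 1, 8, 2, 2, 4
k_cache = torch.zeros(B, T_max, H_kv, D)
v_cache = torch.zeros(B, T_max, H_kv, D)
k = torch.ones(B, T_new, H_kv, D)
v = torch.ones(B, T_new, H_kv, D)
cache_seqlens = torch.tensor([3])

kvcache_update_reference(k_cache, v_cache, k, v, cache_seqlens)
print(k_cache[0, :, 0, 0])  # positions 3 and 4 now hold the new keys
```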

SDPA Fallback Implementation

Layout Conversion

FA3 uses (B, T, H, D) layout, but SDPA expects (B, H, T, D):
# Convert to SDPA layout
q = q.transpose(1, 2)  # (B, T, H, D) -> (B, H, T, D)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Call SDPA
y = F.scaled_dot_product_attention(q, k, v, ...)

# Convert back
y = y.transpose(1, 2)  # (B, H, T, D) -> (B, T, H, D)
Reference: flash_attention.py:115-120
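Note that these transposes are views, not copies, so the layout conversion itself is essentially free. A quick check:

```python
import torch

B, T, H, D = 2, 8, 4, 16
x = torch.randn(B, T, H, D)               # FA3 layout
x_sdpa = x.transpose(1, 2)                # (B, H, T, D) view, no copy
assert x_sdpa.data_ptr() == x.data_ptr()  # same storage: transpose is a view
x_back = x_sdpa.transpose(1, 2)           # round-trip back to (B, T, H, D)
print(torch.equal(x_back, x))  # True: conversion is lossless
```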

Sliding Window Support

SDPA doesn’t natively support sliding windows, so we build an explicit mask:
def _sdpa_attention(q, k, v, window_size, enable_gqa):
    Tq, Tk = q.size(2), k.size(2)
    window = window_size[0]  # left window size

    # Full context, same length → use is_causal=True
    if (window < 0 or window >= Tq) and Tq == Tk:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)

    # Single token (generation) → no mask needed
    if Tq == 1:
        if 0 <= window < Tk:
            k = k[:, :, -(window+1):, :]  # Slice to window
            v = v[:, :, -(window+1):, :]
        return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)

    # Sliding window or chunk inference → explicit mask
    row_idx = (Tk - Tq) + torch.arange(Tq, device=q.device).unsqueeze(1)
    col_idx = torch.arange(Tk, device=q.device).unsqueeze(0)
    mask = col_idx <= row_idx  # Causal mask

    if window >= 0:
        mask = mask & ((row_idx - col_idx) <= window)  # Sliding window

    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
Reference: flash_attention.py:61-94
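For intuition, here is the mask construction above evaluated on a toy case (assumed values: Tq = Tk = 5, window = 2, so each query attends to itself plus the two previous tokens):

```python
import torch

# Causal + sliding-window mask for Tq = Tk = 5, window = 2.
Tq, Tk, window = 5, 5, 2
row_idx = (Tk - Tq) + torch.arange(Tq).unsqueeze(1)  # query positions (column vector)
col_idx = torch.arange(Tk).unsqueeze(0)              # key positions (row vector)
mask = (col_idx <= row_idx) & ((row_idx - col_idx) <= window)
print(mask.int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)
```

Each row is causal (no attention to future keys), and once the row index exceeds the window, the oldest keys fall out of view, producing the banded pattern.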

GQA Support

SDPA has native GQA support (enabled automatically):
enable_gqa = q.size(1) != k.size(1)  # Different number of heads
y = F.scaled_dot_product_attention(q, k, v, enable_gqa=enable_gqa)
Reference: flash_attention.py:72, flash_attention.py:118

KV Cache Pattern

Typical usage in GPT model:
class CausalSelfAttention(nn.Module):
    def forward(self, x, ..., window_size, kv_cache):
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        
        # ... apply RoPE and QK norm ...
        
        if kv_cache is None:
            # Training: no cache
            y = flash_attn.flash_attn_func(
                q, k, v, causal=True, window_size=window_size
            )
        else:
            # Inference: use cache
            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
            y = flash_attn.flash_attn_with_kvcache(
                q, k_cache, v_cache,
                k=k, v=v,
                cache_seqlens=kv_cache.cache_seqlens,
                causal=True,
                window_size=window_size,
            )
            # Advance position after last layer
            if self.layer_idx == kv_cache.n_layers - 1:
                kv_cache.advance(T)
        
        return y
Reference: gpt.py:98-113
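The KVCache object itself is not shown here. The following is a hypothetical minimal sketch consistent with the method names used above (get_layer_cache, cache_seqlens, advance, n_layers); nanochat's actual implementation may differ:

```python
import torch

class KVCache:
    """Hypothetical minimal KV cache matching the usage pattern above;
    the real nanochat implementation may be structured differently."""

    def __init__(self, n_layers, batch_size, max_seq_len, n_kv_head, head_dim,
                 dtype=torch.bfloat16, device='cpu'):
        self.n_layers = n_layers
        # One (K, V) pair per layer, each (B, T_max, H_kv, D).
        shape = (n_layers, 2, batch_size, max_seq_len, n_kv_head, head_dim)
        self.cache = torch.zeros(shape, dtype=dtype, device=device)
        self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)

    def get_layer_cache(self, layer_idx):
        return self.cache[layer_idx, 0], self.cache[layer_idx, 1]

    def advance(self, n_tokens):
        self.cache_seqlens += n_tokens

cache = KVCache(n_layers=2, batch_size=1, max_seq_len=16, n_kv_head=2, head_dim=8)
k_cache, v_cache = cache.get_layer_cache(0)
print(k_cache.shape)  # torch.Size([1, 16, 2, 8])
cache.advance(4)
print(cache.cache_seqlens)  # tensor([4], dtype=torch.int32)
```

Storing all layers in one tensor keeps allocation to a single call; advancing the position only after the last layer (as in the forward pass above) ensures every layer writes its K/V at the same offset for a given step.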

Performance Characteristics

FA3 (Hopper)

  • Speed: ~3x faster than SDPA on H100
  • Memory: More efficient (FlashAttention algorithm)
  • Precision: BFloat16
  • Limitations: Hopper-only (sm90)

SDPA Fallback

  • Speed: Good performance on all hardware
  • Memory: Standard memory-efficient attention
  • Precision: Adapts to input dtype
  • Coverage: Works everywhere (CUDA, MPS, CPU)

Testing and Override

For testing, you can force a specific implementation:
import nanochat.flash_attention as fa_module

# Force SDPA (even on Hopper)
fa_module._override_impl = 'sdpa'

# Force FA3 (will assert if not available)
fa_module._override_impl = 'fa3'

# Auto-detect (default)
fa_module._override_impl = None
Reference: flash_attention.py:45-55

Common Patterns

Training (Full Sequences)

# Shapes: (B=32, T=2048, H=12, D=64)
y = flash_attn.flash_attn_func(
    q, k, v,
    causal=True,
    window_size=(-1, 0),  # Full context
)

Training (Sliding Window)

# Attend to last 1024 tokens only
y = flash_attn.flash_attn_func(
    q, k, v,
    causal=True,
    window_size=(1024, 0),
)

Inference (Single Token)

# q: (B=1, T_new=1, H=12, D=64)
# k_cache, v_cache: (B=1, T_max=4096, H_kv=4, D=64)
# cache_seqlens: (B=1,) = [current_position]

y = flash_attn.flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k, v=v,  # (B=1, 1, H_kv=4, D=64)
    cache_seqlens=cache_seqlens,
    causal=True,
    window_size=(-1, 0),
)

cache_seqlens += 1  # Advance position

Inference (Chunk)

# Process multiple tokens at once (e.g., prompt encoding)
# q: (B=1, T_new=128, H=12, D=64)

y = flash_attn.flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k, v=v,  # (B=1, 128, H_kv=4, D=64)
    cache_seqlens=cache_seqlens,
    causal=True,
    window_size=(-1, 0),
)

cache_seqlens += 128  # Advance by chunk size

Advantages of Unified Interface

  1. Write once, run anywhere: Same code works on H100, A100, RTX 4090, Mac, etc.
  2. No conditional logic in model: Model code doesn’t need to check hardware
  3. Easy testing: Test SDPA path on Hopper by overriding
  4. Future-proof: When FA3 supports Blackwell, no code changes needed

Limitations

FA3 Limitations

  • Hopper-only (H100, H200)
  • BFloat16 only
  • Requires kernels package (varunneal/flash-attention-3)

SDPA Limitations

  • Slower than FA3 on Hopper
  • Sliding window requires explicit mask (memory overhead)
  • Chunk inference needs careful mask construction
