Overview
Nanochat uses a unified Flash Attention interface that automatically selects the best implementation based on available hardware:
- Flash Attention 3 (FA3): on Hopper GPUs (H100, H200) - fastest
- PyTorch SDPA: Fallback for Ada, Blackwell, Ampere, MPS, and CPU
Key Features
- Drop-in replacement for FA3 with identical API
- Zero-overhead hardware detection (once at import time)
- Supports causal attention and sliding window attention
- KV cache support for inference
- Automatic layout conversion for SDPA fallback
Usage
Hardware Detection
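The selection logic can be sketched as a small pure-Python predicate. `pick_backend` is a hypothetical helper for illustration; in nanochat the equivalent check runs once at import time in `flash_attention.py`:

```python
def pick_backend(capability, fa3_available):
    """Choose the attention backend from a (major, minor) CUDA compute
    capability. FA3 kernels are built for sm90, so only 9.x qualifies;
    everything else falls back to PyTorch SDPA.
    (Hypothetical helper mirroring the import-time check.)"""
    if fa3_available and capability is not None and capability[0] == 9:
        return "fa3"
    return "sdpa"

print(pick_backend((9, 0), True))    # H100/H200 -> fa3
print(pick_backend((8, 9), True))    # RTX 4090 (Ada) -> sdpa
print(pick_backend((10, 0), True))   # Blackwell -> sdpa (kernels are sm90-only)
print(pick_backend(None, False))     # CPU / MPS -> sdpa
```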
Loading FA3
FA3 is loaded at import time if conditions are met:
- ✅ Hopper (H100, H200) - compute capability 9.0
- ❌ Blackwell - compute capability 10.0 (needs recompilation)
- ❌ Ada (RTX 4090) - compute capability 8.9
- ❌ Ampere (A100) - compute capability 8.0
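A sketch of the import-time gate, assuming a try/except pattern: try FA3 only on compute capability 9.x and fall back to SDPA on any failure. The FA3 module name below is illustrative; nanochat actually obtains the kernels via the `kernels` package:

```python
# Sketch of the import-time gate (module name illustrative).
HAS_FA3 = False
try:
    import torch
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 9:
        from flash_attn_interface import flash_attn_func  # sm90 kernels
        HAS_FA3 = True
except ImportError:
    pass  # kernels package missing (or no torch): use the SDPA fallback

print("FA3" if HAS_FA3 else "SDPA")
```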
Why Not Blackwell?
FA3 kernels are compiled specifically for sm90 (Hopper). Blackwell (sm100) requires kernel recompilation. The SDPA fallback provides good performance until FA3 adds Blackwell support.
Detection Result
API Reference
flash_attn_func
(batch, sequence, heads, dim) - FA3’s native layout
Window Size:
- (-1, 0): Full causal attention
- (N, 0): Sliding window, attend to the last N+1 tokens
- (-1, -1): Full bidirectional (not causal)
Output shape: (B, T, H, D)
Reference: flash_attention.py:99-120
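The window_size convention can be captured by a small predicate; this is a sketch of the semantics, not the kernel:

```python
def may_attend(i, j, window_size):
    """True if query position i may attend key position j under the
    (left, right) window convention, where -1 means unlimited."""
    left, right = window_size
    if left >= 0 and j < i - left:
        return False  # too far in the past
    if right >= 0 and j > i + right:
        return False  # too far in the future
    return True

# (-1, 0): full causal; (2, 0): last 3 tokens; (-1, -1): bidirectional
print(may_attend(5, 0, (-1, 0)))   # True  (causal, any past token)
print(may_attend(5, 6, (-1, 0)))   # False (no future tokens)
print(may_attend(5, 2, (2, 0)))    # False (outside the window of 3)
print(may_attend(2, 5, (-1, -1)))  # True  (bidirectional)
```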
flash_attn_with_kvcache
Updates k_cache and v_cache in-place.
Cache Management:
- k_cache, v_cache: Pre-allocated tensors (usually T_max = 4096 or larger)
- cache_seqlens: Tensor tracking the current position (shape: (B,))
- New keys/values are inserted at positions cache_seqlens[0] : cache_seqlens[0] + T_new
Query/output shape: (B, T_new, H, D)
Reference: flash_attention.py:123-169
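The cache update can be sketched in plain Python (one batch row, list-based; the real code operates on pre-allocated tensors):

```python
def cache_insert(k_cache, cache_seqlen, k_new):
    """Insert new keys at positions cache_seqlen : cache_seqlen + T_new,
    in place, and return the advanced position (single batch row sketch)."""
    t_new = len(k_new)
    assert cache_seqlen + t_new <= len(k_cache), "KV cache overflow"
    k_cache[cache_seqlen:cache_seqlen + t_new] = k_new
    return cache_seqlen + t_new

cache = [None] * 8                                  # pre-allocated, T_max = 8
pos = cache_insert(cache, 0, ["k0", "k1", "k2"])    # chunk prefill (T_new = 3)
pos = cache_insert(cache, pos, ["k3"])              # single-token decode
print(pos, cache[:4])                               # 4 ['k0', 'k1', 'k2', 'k3']
```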
SDPA Fallback Implementation
Layout Conversion
FA3 uses the (B, T, H, D) layout, but SDPA expects (B, H, T, D):
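A runnable sketch of the conversion wrapped around torch's SDPA (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def sdpa_fallback(q, k, v, is_causal=True):
    """Accept FA3's (B, T, H, D) layout, transpose to SDPA's (B, H, T, D),
    attend, and transpose back."""
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))       # (B, H, T, D)
    y = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    return y.transpose(1, 2)                               # (B, T, H, D)

q = k = v = torch.randn(2, 16, 4, 8)                       # (B, T, H, D)
print(sdpa_fallback(q, k, v).shape)                        # torch.Size([2, 16, 4, 8])
```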
Sliding Window Support
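A sketch of the boolean mask such a fallback needs, where `window` counts the extra past tokens each query may see beyond its own position (True = may attend):

```python
import torch

def sliding_window_mask(T, window, device=None):
    """Boolean (T, T) mask: query i may attend key j iff i - window <= j <= i,
    i.e. causal attention over the last window + 1 tokens. Suitable for the
    attn_mask argument of F.scaled_dot_product_attention."""
    i = torch.arange(T, device=device).unsqueeze(1)
    j = torch.arange(T, device=device).unsqueeze(0)
    return (j <= i) & (j >= i - window)

print(sliding_window_mask(4, 1).int())
```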
SDPA doesn’t natively support sliding windows, so we build an explicit mask:
GQA Support
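Grouped-query attention shares each KV head across several query heads. A sketch of the manual expansion (recent PyTorch SDPA can instead broadcast this natively, e.g. via enable_gqa=True; head counts below are illustrative):

```python
import torch

# GQA sketch: 8 query heads share 2 KV heads (4 queries per KV head).
# The manual equivalent of native broadcasting repeats each KV head
# along the head dimension to match the query heads.
n_head, n_kv_head = 8, 2
k = torch.randn(2, n_kv_head, 16, 8)                  # (B, H_kv, T, D)
k_rep = k.repeat_interleave(n_head // n_kv_head, dim=1)
print(k_rep.shape)                                    # torch.Size([2, 8, 16, 8])
```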
SDPA has native GQA support (enabled automatically):
KV Cache Pattern
Typical usage in the GPT model:
Performance Characteristics
FA3 (Hopper)
- Speed: ~3x faster than SDPA on H100
- Memory: More efficient (FlashAttention algorithm)
- Precision: BFloat16
- Limitations: Hopper-only (sm90)
SDPA Fallback
- Speed: Good performance on all hardware
- Memory: Standard memory-efficient attention
- Precision: Adapts to input dtype
- Coverage: Works everywhere (CUDA, MPS, CPU)
Testing and Override
For testing, you can force a specific implementation:
Common Patterns
Training (Full Sequences)
Training (Sliding Window)
Inference (Single Token)
Inference (Chunk)
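The four patterns above differ only in window_size and in whether a cache position advances. A stub-based sketch of the bookkeeping (`attn` is a stand-in that just counts visible keys, not the real kernel):

```python
def attn(q_len, kv_len, window_size):
    """Stand-in for the attention call: for each of the q_len queries,
    return how many cached keys it may attend (illustration only)."""
    left, _right = window_size
    counts = []
    for i in range(kv_len - q_len, kv_len):   # absolute query positions
        lo = 0 if left < 0 else max(0, i - left)
        counts.append(i - lo + 1)             # causal: keys lo..i
    return counts

seqlen = 0
# Training / chunk prefill of 4 tokens, full causal:
seqlen += 4
print(attn(4, seqlen, (-1, 0)))   # [1, 2, 3, 4]
# Single-token decode with a sliding window over the last 3 tokens:
seqlen += 1
print(attn(1, seqlen, (2, 0)))    # [3]
```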
Advantages of Unified Interface
- Write once, run anywhere: Same code works on H100, A100, RTX 4090, Mac, etc.
- No conditional logic in model: Model code doesn’t need to check hardware
- Easy testing: Test SDPA path on Hopper by overriding
- Future-proof: When FA3 supports Blackwell, no code changes needed
Limitations
FA3 Limitations
- Hopper-only (H100, H200)
- BFloat16 only
- Requires the kernels package (varunneal/flash-attention-3)
SDPA Limitations
- Slower than FA3 on Hopper
- Sliding window requires explicit mask (memory overhead)
- Chunk inference needs careful mask construction
Related
GPT Architecture
How the model uses Flash Attention
Optimizer
MuonAdamW optimizer details