## Overview

FP8 (8-bit floating point) training can speed up training by approximately 2x on modern GPUs (H100 and newer) while maintaining model quality. nanochat includes a minimal FP8 implementation that serves as a drop-in replacement for standard linear layers.

## Requirements

- Hardware: NVIDIA H100 or newer GPU with FP8 hardware support
- Software: PyTorch with `torch._scaled_mm` support
- nanochat's custom FP8 implementation (in `nanochat/fp8.py`)
## Quick Start

Enable FP8 training by adding the `--fp8` flag to your training command:
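An illustrative invocation (the `torchrun` arguments and GPU count here are examples; adapt the script path and flags to your setup):

```shell
# Launch multi-GPU base training with FP8 enabled (adjust --nproc_per_node to your GPU count)
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --fp8
```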
## How FP8 Works

FP8 training wraps each matrix multiplication (matmul) with quantization/dequantization:

1. Compute scale: `scale = FP8_MAX / max(|tensor|)` for each operand
2. Quantize: Convert the tensor to FP8 format, clamping values to the representable range
3. Matmul: Use `torch._scaled_mm` (cuBLAS FP8 kernel, ~2x faster than BF16)
4. Dequantize: `_scaled_mm` handles this internally using the inverse scales

Each linear layer involves three matmuls:

- Forward: `output = input @ weight.T`
- Backward (grad_input): `grad_input = grad_output @ weight`
- Backward (grad_weight): `grad_weight = grad_output.T @ input`
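The scale/quantize/dequantize steps can be sketched in plain Python. This is a minimal sketch of tensorwise scaling using `FP8_MAX = 448` (the `float8_e4m3fn` limit); the actual implementation in `nanochat/fp8.py` operates on torch tensors and dispatches the matmul to `torch._scaled_mm`:

```python
FP8_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def quantize(tensor):
    """Tensorwise scaling: one scalar scale for the whole tensor."""
    amax = max(abs(v) for v in tensor)
    scale = FP8_MAX / amax  # stretch the tensor to fill the FP8 range
    # Clamp to the representable range (FP8 rounding itself is ignored here)
    quantized = [max(-FP8_MAX, min(FP8_MAX, v * scale)) for v in tensor]
    return quantized, 1.0 / scale  # the inverse scale is used to dequantize

x = [0.5, -2.0, 1.25]
q, inv_scale = quantize(x)
# Multiplying by the inverse scale recovers the original values
restored = [v * inv_scale for v in q]
```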
## FP8 Data Types

nanochat uses both FP8 formats, following standard conventions:

- `float8_e4m3fn`: 4-bit exponent, 3-bit mantissa, range [-448, 448]
  - Higher precision (more mantissa bits)
  - Used for input and weight tensors
- `float8_e5m2`: 5-bit exponent, 2-bit mantissa, range [-57344, 57344]
  - Wider range (more exponent bits)
  - Used for gradient tensors (which can be larger)
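The ranges above follow directly from the bit layouts. A short derivation (in `e5m2` the all-ones exponent encodes inf/NaN as in IEEE 754; in `e4m3fn` it is a normal exponent, with only the all-ones mantissa reserved for NaN):

```python
def fp8_max(exp_bits, man_bits, ieee_special_exponent):
    """Largest finite value of an FP8 format."""
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_special_exponent:
        # e5m2: top exponent code is reserved for inf/NaN,
        # so the largest usable exponent is one below it.
        max_exp = (2 ** exp_bits - 2) - bias
        max_mant = 2 - 2 ** -man_bits           # all-ones mantissa is allowed
    else:
        # e4m3fn: the top exponent code is usable, but the all-ones
        # mantissa with it encodes NaN, so the mantissa tops out one below.
        max_exp = (2 ** exp_bits - 1) - bias
        max_mant = 2 - 2 * 2 ** -man_bits
    return max_mant * 2 ** max_exp

e4m3_max = fp8_max(4, 3, False)  # 448.0
e5m2_max = fp8_max(5, 2, True)   # 57344.0
```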
## Scaling Recipes

The `--fp8-recipe` flag controls the scaling strategy:

### tensorwise (default, recommended)

- One scalar scale per entire tensor
- Faster: cuBLAS handles scaling directly
- ~150 lines of code in nanochat
- Used in nanochat's speedrun configuration

### rowwise

- Separate scale per row
- More accurate but slower (requires a CUTLASS kernel)
- Requires the full torchao library (not included in nanochat's minimal implementation)
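The accuracy difference between the two recipes comes down to how many scales are computed. A sketch on a plain nested-list "matrix" (illustrative only; the helper names are not from nanochat or torchao):

```python
FP8_MAX = 448.0

def tensorwise_scale(matrix):
    """One scale for the whole tensor."""
    amax = max(abs(v) for row in matrix for v in row)
    return FP8_MAX / amax

def rowwise_scales(matrix):
    """One scale per row: a large outlier in one row no longer
    squashes the effective precision of every other row."""
    return [FP8_MAX / max(abs(v) for v in row) for row in matrix]

m = [[0.1, 0.2],      # small-magnitude row
     [100.0, 50.0]]   # large-magnitude row
# Tensorwise: the 100.0 outlier forces a single small scale on both rows.
# Rowwise: each row independently fills the FP8 range.
ts = tensorwise_scale(m)
rs = rowwise_scales(m)
```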
## Implementation Details

nanochat includes a minimal ~150-line FP8 implementation in `nanochat/fp8.py` that replaces torchao's ~2000-line implementation.

Linear layers are converted to FP8 when training with `--fp8`. A filter ensures only suitable layers are converted:

- Dimensions must be divisible by 16 (a hardware requirement)
- Minimum dimension size of 128 (too small means the conversion isn't worth the overhead)

The conversion and filtering logic lives in `scripts/base_train.py:174-188`.
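The two dimension rules can be expressed as a small predicate. This is an illustrative sketch (`fp8_layer_ok` is a hypothetical name, not the actual filter in `scripts/base_train.py`):

```python
def fp8_layer_ok(in_features, out_features):
    """Return True if a linear layer is suitable for FP8 conversion."""
    dims = (in_features, out_features)
    # Hardware requirement: FP8 matmul kernels need dims divisible by 16
    if any(d % 16 != 0 for d in dims):
        return False
    # Heuristic: tiny layers don't amortize the quantization overhead
    if min(dims) < 128:
        return False
    return True
```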
## Performance

Typical speed improvements on H100:

- ~2x faster matmul operations vs BF16
- ~30-40% overall training speedup (matmuls dominate but are not 100% of training time)
- The current leaderboard entry #2 uses FP8 to achieve a 2.91-hour time-to-GPT-2
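The gap between the 2x matmul speedup and the ~30-40% end-to-end speedup is Amdahl's law: only the matmul fraction of a training step gets faster. A worked example, assuming matmuls take roughly 55% of step time (an illustrative figure, not measured from nanochat):

```python
def overall_speedup(matmul_fraction, matmul_speedup):
    """Amdahl's law: only the matmul portion of a step is accelerated."""
    new_time = (1 - matmul_fraction) + matmul_fraction / matmul_speedup
    return 1 / new_time

# If matmuls are 55% of step time and FP8 makes them 2x faster:
s = overall_speedup(0.55, 2.0)  # ~1.38, i.e. ~38% faster end to end
```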
## Evaluation in BF16

Evaluation runs in BF16 for consistency, even when training with FP8. The training script automatically disables FP8 during evaluation.

## Troubleshooting

### "FP8 training requires CUDA"

FP8 requires CUDA GPUs. If you see this warning, your device type is not supported.

### Dimensions not divisible by 16

If too many layers fail the dimension filter, they are skipped and a warning is printed.

## Compatibility
FP8 works with:

- ✅ Multi-GPU training via `torchrun`
- ✅ Gradient accumulation
- ✅ `torch.compile`
- ✅ Mixed precision (autocast)
- ❌ Non-CUDA devices (CPU, MPS)
- ❌ Pre-Hopper GPUs (A100, V100, etc.)
## Example: Speedrun with FP8

The current speedrun configuration uses FP8 to achieve sub-3-hour GPT-2 training.

## Further Reading

- `nanochat/fp8.py` - Full implementation with detailed comments
- `scripts/base_train.py:161-236` - FP8 initialization and management
- torchao documentation - Full torchao library (optional)
- NVIDIA FP8 formats - Official specification