The `CausalSelfAttention` class implements multi-head causal self-attention, the core mechanism that allows the model to attend to previous tokens in the sequence.
Class definition
model.py:29-50
Parameters
Configuration object containing model hyperparameters. The attention mechanism uses `n_embd`, `n_head`, `dropout`, `bias`, and `block_size`.

Components
- Combined linear projection that computes query, key, and value for all heads in a single matrix multiplication. Maps from `n_embd` to `3 * n_embd`.
- Output projection that combines all attention heads back to `n_embd` dimensions.
- Dropout applied to the attention weights (only used in the manual attention implementation).
- Dropout applied to the final output before the residual connection.
- Number of attention heads.
- Embedding dimension.
- Dropout probability.
- Whether Flash Attention is available (requires PyTorch >= 2.0).
- Causal attention mask (a lower-triangular matrix), registered as a buffer only when Flash Attention is not available.
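Putting the components above together, the constructor might look roughly like the following sketch. The attribute names (`c_attn`, `c_proj`, `attn_dropout`, `resid_dropout`, `mask`) are assumptions based on the descriptions above, not confirmed by this page:

```python
import torch
import torch.nn as nn

class CausalSelfAttention(nn.Module):
    # Sketch of the constructor; attribute names are illustrative.
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # combined query/key/value projection for all heads: n_embd -> 3 * n_embd
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection back to n_embd
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # dropout on attention weights (manual path) and on the final output
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # Flash Attention is exposed as scaled_dot_product_attention in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            # causal mask: lower-triangular (block_size, block_size) matrix,
            # broadcastable over (B, n_head, T, T) attention scores
            self.register_buffer("mask", torch.tril(
                torch.ones(config.block_size, config.block_size)
            ).view(1, 1, config.block_size, config.block_size))
```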
Forward pass
The forward method implements causal self-attention with two possible execution paths.
model.py:52-76
Input
Input tensor of shape `(B, T, C)`, where:
- `B` = batch size
- `T` = sequence length
- `C` = embedding dimension (`n_embd`)
Output
Output tensor of shape `(B, T, C)` after applying causal self-attention and the output projection.

Flash Attention vs manual implementation
The class automatically detects whether Flash Attention is available and chooses the appropriate implementation:

Flash Attention (PyTorch >= 2.0)
- Significantly faster due to fused CUDA kernels
- More memory efficient
- Automatically handles causal masking with `is_causal=True`
Flash Attention makes “GPU go brrrrr” by using optimized CUDA kernels that fuse multiple operations and reduce memory bandwidth requirements.
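When this path is taken, the attention computation reduces to a single call to PyTorch's `torch.nn.functional.scaled_dot_product_attention` (available since PyTorch 2.0), sketched here with illustrative shapes:

```python
import torch
import torch.nn.functional as F

B, n_head, T, head_dim = 2, 4, 8, 16
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_head, T, head_dim)
v = torch.randn(B, n_head, T, head_dim)

# Fused attention kernel; is_causal=True applies the lower-triangular
# mask internally, so no explicit mask buffer is needed.
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                   dropout_p=0.0, is_causal=True)
print(y.shape)  # torch.Size([2, 4, 8, 16])
```

The output keeps the `(B, n_head, T, head_dim)` layout of the inputs.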
Manual implementation (PyTorch < 2.0)
1. Compute attention scores: `Q @ K^T / sqrt(head_dim)`
2. Apply the causal mask (prevent attending to future tokens)
3. Apply softmax to get attention probabilities
4. Apply dropout to the attention weights
5. Multiply by the values to get the output
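The steps above can be sketched as a standalone function (a minimal reference version, not the file's exact code):

```python
import math
import torch
import torch.nn.functional as F

def manual_causal_attention(q, k, v, dropout_p=0.0):
    """Manual causal attention over q, k, v of shape (B, n_head, T, head_dim)."""
    T = q.size(-2)
    # 1. attention scores, scaled by sqrt(head_dim)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    # 2. causal mask: future positions set to -inf before softmax
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    att = att.masked_fill(~causal, float('-inf'))
    # 3. softmax over the key dimension gives attention probabilities
    att = F.softmax(att, dim=-1)
    # 4. dropout on the attention weights
    att = F.dropout(att, p=dropout_p)
    # 5. weighted sum of values
    return att @ v
```

With dropout disabled, this should match `F.scaled_dot_product_attention(q, k, v, is_causal=True)` up to floating-point error.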
Multi-head attention
The implementation uses multi-head attention, where the embedding dimension is split across multiple heads:
- Each head operates on `head_dim = n_embd // n_head` dimensions
- Heads are processed in parallel (reshaped so that the head axis acts as a batch dimension)
- Outputs are concatenated and projected back to `n_embd`
Shape transformations
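A plausible walk-through of the reshapes, assuming the combined QKV projection described above (variable names are illustrative):

```python
import torch
import torch.nn as nn

B, T, C = 2, 8, 32       # batch size, sequence length, n_embd
n_head = 4
head_dim = C // n_head   # 32 // 4 = 8

x = torch.randn(B, T, C)
c_attn = nn.Linear(C, 3 * C)

# (B, T, C) -> (B, T, 3C) -> three (B, T, C) tensors for q, k, v
q, k, v = c_attn(x).split(C, dim=2)
# (B, T, C) -> (B, T, n_head, head_dim) -> (B, n_head, T, head_dim)
q = q.view(B, T, n_head, head_dim).transpose(1, 2)
assert q.shape == (B, n_head, T, head_dim)
# k and v get the same treatment; attention then treats (B, n_head)
# as batch dimensions and produces y of shape (B, n_head, T, head_dim).
# Finally the heads are re-assembled: (B, n_head, T, head_dim) -> (B, T, C)
y = q.transpose(1, 2).contiguous().view(B, T, C)
assert y.shape == (B, T, C)
```

The final `(B, T, C)` tensor is what gets passed through the output projection back to `n_embd`.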
Causal masking
The attention is “causal” because each position can only attend to earlier positions in the sequence:
- Position 0 can only attend to position 0
- Position 1 can attend to positions 0-1
- Position T-1 can attend to positions 0 through T-1

This is enforced differently in the two paths:
- Flash Attention: the `is_causal=True` parameter
- Manual: a lower-triangular mask that sets future positions to `-inf` before softmax
The causal mask is essential for autoregressive language modeling, ensuring the model can only use past context to predict future tokens.
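The effect of the masking step can be seen on a toy sequence of length 4 (a minimal sketch of the manual path, with all scores set to zero so each row's probabilities are uniform over its visible positions):

```python
import torch

T = 4
scores = torch.zeros(T, T)  # stand-in for Q @ K^T / sqrt(head_dim)
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
masked = scores.masked_fill(~mask, float('-inf'))
probs = torch.softmax(masked, dim=-1)
# Row i has nonzero probability only for columns 0..i:
print(probs)
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000, 0.0000],
#         [0.3333, 0.3333, 0.3333, 0.0000],
#         [0.2500, 0.2500, 0.2500, 0.2500]])
```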