The CausalSelfAttention class implements multi-head causal self-attention, the core mechanism that allows the model to attend to previous tokens in the sequence.

Class definition

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))
Location: model.py:29-50

Parameters

config
GPTConfig
required
Configuration object containing model hyperparameters. The attention mechanism uses n_embd, n_head, dropout, bias, and block_size.
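For reference, a minimal stand-in for the configuration object might look like this. This is a sketch covering only the fields `CausalSelfAttention` reads; nanoGPT's real `GPTConfig` also carries fields such as `vocab_size` and `n_layer`, and the defaults below are illustrative:

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    # minimal sketch: only the fields CausalSelfAttention uses
    block_size: int = 1024   # maximum sequence length (sizes the causal mask)
    n_embd: int = 768        # embedding dimension
    n_head: int = 12         # number of attention heads
    dropout: float = 0.0     # dropout probability
    bias: bool = True        # whether Linear layers use a bias term

config = GPTConfig()
# the assert in __init__ requires an even split across heads
assert config.n_embd % config.n_head == 0
print(config.n_embd // config.n_head)  # per-head dimension: 64
```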

Components

c_attn
nn.Linear
Combined linear projection that computes query, key, and value for all heads in a single matrix multiplication. Maps from n_embd to 3 * n_embd.
c_proj
nn.Linear
Output projection that combines all attention heads back to n_embd dimensions.
attn_dropout
nn.Dropout
Dropout applied to attention weights (only used in manual attention implementation).
resid_dropout
nn.Dropout
Dropout applied to the final output before the residual connection.
n_head
int
Number of attention heads.
n_embd
int
Embedding dimension.
dropout
float
Dropout probability.
flash
bool
Whether Flash Attention is available (requires PyTorch >= 2.0).
bias
torch.Tensor | None
Causal attention mask (lower triangular matrix). Only registered when Flash Attention is not available. Despite the name, this buffer is the causal mask, not a learned bias term; it is unrelated to the config.bias flag.

Forward pass

The forward method implements causal self-attention with two possible execution paths:
def forward(self, x):
    B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

    # calculate query, key, values for all heads in batch and move head forward to be the batch dim
    q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
    k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
    q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
    v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

    # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
    if self.flash:
        # efficient attention using Flash Attention CUDA kernels
        y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
    else:
        # manual implementation of attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

    # output projection
    y = self.resid_dropout(self.c_proj(y))
    return y
Location: model.py:52-76

Input

x
torch.Tensor
Input tensor of shape (B, T, C) where:
  • B = batch size
  • T = sequence length
  • C = embedding dimension (n_embd)

Output

output
torch.Tensor
Output tensor of shape (B, T, C) after applying causal self-attention and output projection.

Flash Attention vs manual implementation

The constructor checks hasattr(torch.nn.functional, 'scaled_dot_product_attention') to detect whether Flash Attention is available, and forward chooses the corresponding path:

Flash Attention (PyTorch >= 2.0)

y = torch.nn.functional.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    dropout_p=self.dropout if self.training else 0,
    is_causal=True
)
Advantages:
  • Significantly faster due to fused CUDA kernels
  • More memory efficient
  • Automatically handles causal masking with is_causal=True
Flash Attention makes “GPU go brrrrr” by using optimized CUDA kernels that fuse multiple operations and reduce memory bandwidth requirements.
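A minimal, self-contained sketch of calling scaled_dot_product_attention directly, outside the class (the tensor sizes are arbitrary, chosen for illustration):

```python
import torch
import torch.nn.functional as F

B, n_head, T, head_dim = 2, 4, 8, 16
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_head, T, head_dim)
v = torch.randn(B, n_head, T, head_dim)

# same call as in forward(); dropout_p=0.0 corresponds to eval mode
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                   dropout_p=0.0, is_causal=True)

# output keeps the per-head (B, n_head, T, head_dim) layout
assert y.shape == (B, n_head, T, head_dim)
# causality check: position 0 attends only to itself, so its output is v at position 0
assert torch.allclose(y[:, :, 0, :], v[:, :, 0, :], atol=1e-5)
```

Note that no explicit mask tensor is built here; is_causal=True lets the kernel apply the triangular mask internally, which is part of the memory saving.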

Manual implementation (PyTorch < 2.0)

att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
y = att @ v
Steps:
  1. Compute attention scores: Q @ K^T / sqrt(head_dim)
  2. Apply causal mask (prevent attending to future tokens)
  3. Apply softmax to get attention probabilities
  4. Apply dropout to attention weights
  5. Multiply by values to get output
If you see “WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0”, consider upgrading PyTorch for significantly better performance.
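The five steps above can be traced in plain Python for a single head (a toy sketch on lists, no PyTorch required; the causal mask shows up as simply not scoring positions j > i, the analogue of the -inf fill):

```python
import math

def causal_attention_1head(q, k, v):
    """Toy single-head causal attention on Python lists of T vectors."""
    T, hs = len(q), len(q[0])
    out = []
    for i in range(T):
        # steps 1-2: scaled dot-product scores against keys 0..i only
        # (excluding j > i is the causal mask)
        scores = [sum(q[i][d] * k[j][d] for d in range(hs)) / math.sqrt(hs)
                  for j in range(i + 1)]
        # step 3: softmax over the visible positions
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        att = [e / total for e in exps]
        # step 5: weighted sum of values (step 4, dropout, is skipped at inference)
        out.append([sum(att[j] * v[j][d] for j in range(i + 1))
                    for d in range(hs)])
    return out

q = k = v = [[1.0, 0.0], [0.0, 1.0]]
out = causal_attention_1head(q, k, v)
# position 0 can only see itself, so its output is exactly v[0]
assert out[0] == [1.0, 0.0]
# attention weights sum to 1, and position 1 attends more to itself than to position 0
assert abs(out[1][0] + out[1][1] - 1.0) < 1e-9
assert out[1][1] > out[1][0]
```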

Multi-head attention

The implementation uses multi-head attention where the embedding dimension is split across multiple heads:
  • Each head operates on head_dim = n_embd // n_head dimensions
  • Heads process in parallel (reshaped to batch dimension)
  • Outputs are concatenated and projected back to n_embd

Shape transformations

Input: (B, T, n_embd)
  ↓ c_attn projection
(B, T, 3*n_embd) → split into Q, K, V → (B, T, n_embd) each
  ↓ view to (B, T, n_head, head_dim), then transpose(1, 2)
(B, n_head, T, head_dim)
  ↓ attention computation
(B, n_head, T, head_dim)
  ↓ transpose and reshape
(B, T, n_embd)
  ↓ c_proj output projection
(B, T, n_embd)
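The pipeline above can be traced in plain Python (an illustrative sketch that only tracks shapes, no PyTorch required):

```python
def attention_shapes(B, T, n_embd, n_head):
    """Return the named tensor shapes at each stage of the forward pass."""
    assert n_embd % n_head == 0
    head_dim = n_embd // n_head
    return [
        ("input",             (B, T, n_embd)),
        ("c_attn output",     (B, T, 3 * n_embd)),
        ("q/k/v after split", (B, T, n_embd)),
        ("per-head layout",   (B, n_head, T, head_dim)),
        ("attention output",  (B, n_head, T, head_dim)),
        ("re-assembled",      (B, T, n_embd)),
        ("c_proj output",     (B, T, n_embd)),
    ]

for name, shape in attention_shapes(B=2, T=8, n_embd=768, n_head=12):
    print(f"{name}: {shape}")
```

With the GPT-2 defaults (n_embd=768, n_head=12), each head works in a 64-dimensional subspace.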

Causal masking

The attention is “causal” because each position can only attend to earlier positions in the sequence:
  • Position 0 can only attend to position 0
  • Position 1 can attend to positions 0-1
  • Position T-1 (the last position in a sequence of length T) can attend to positions 0 through T-1
This is enforced by:
  • Flash Attention: is_causal=True parameter
  • Manual: Lower triangular mask that sets future positions to -inf before softmax
The causal mask is essential for autoregressive language modeling, ensuring the model can only use past context to predict future tokens.
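The lower-triangular mask stored by register_buffer can be reproduced in plain Python (a toy analogue of torch.tril(torch.ones(T, T)), omitting the extra (1, 1, T, T) view used for broadcasting):

```python
def causal_mask(T):
    # pure-Python analogue of torch.tril(torch.ones(T, T)):
    # row i allows attention to columns j <= i
    return [[1 if j <= i else 0 for j in range(T)] for i in range(T)]

for row in causal_mask(4):
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```

In the manual path, positions where this mask is 0 are filled with -inf before the softmax, so they receive exactly zero attention weight.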
