AttentionConfig
Configuration dataclass for multi-head attention hyperparameters.

Constructor
- Model dimension. Must be divisible by n_heads.
- Number of query heads in multi-head attention.
- Enable Rotary Position Embeddings (Su et al., 2021) for relative position encoding.
- Base frequency for RoPE. Higher values give slower position decay.
- Scaling factor for RoPE frequencies to extend context length.
- Enable attention sinks (Xiao et al., 2023) for long-context stability.
- Number of learnable sink tokens prepended to sequences. Must be > 0 when use_attention_sinks=True.
- Enable Grouped Query Attention (Ainslie et al., 2023) to reduce KV cache memory.
- Number of KV head groups. Must divide n_heads when use_gqa=True.
- Dropout probability applied to attention weights.
- Use PyTorch scaled_dot_product_attention (includes Flash Attention) for a 2-4x speedup.
MultiHeadAttention
Multi-head self-attention implementing scaled dot-product attention with modern enhancements. Supports:

- Rotary Position Embeddings (RoPE) for relative position encoding
- Grouped Query Attention (GQA) to reduce memory usage
- Attention sinks for long-context stability
- Flash Attention via PyTorch SDPA
Mathematical formulation
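The core computation follows the standard scaled dot-product form (RoPE, when enabled, rotates Q and K before this step; GQA changes only how K/V heads are shared across query heads):

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_{\text{head}}}}\right) V
$$

The scaling factor $1/\sqrt{d_{\text{head}}}$ corresponds to the `scale` attribute described below.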
Constructor
Attention configuration containing all hyperparameters.
Attributes
- The configuration object passed during initialization.
- Dimension per attention head: d_model / n_heads.
- Number of query heads (equals n_heads from config).
- Number of key/value heads. Equals gqa_groups when using GQA, otherwise n_heads.
- Attention scaling factor: 1 / sqrt(head_dim).
- Query projection: d_model -> d_model.
- Key projection: d_model -> (num_kv_heads * head_dim).
- Value projection: d_model -> (num_kv_heads * head_dim).
- Output projection: d_model -> d_model.
- Learnable attention sink tokens of shape (num_attention_sinks, d_model) when enabled.
forward
- Input tensor of shape (batch, seq_len, d_model).
- Additive attention bias of shape (batch, 1, seq_len, seq_len). Use 0 for valid positions and large negative values (e.g., -inf) for masked positions.
Returns
Attention output of shape (batch, seq_len, d_model).
Example
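A minimal self-contained sketch of the forward contract documented above (the class and attribute names here are illustrative, not the library's code; it shows the projection/reshape pattern and the additive-mask convention):

```python
import math
import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    """Illustrative multi-head attention matching the documented shapes."""
    def __init__(self, d_model=64, n_heads=8):
        super().__init__()
        assert d_model % n_heads == 0  # d_model must be divisible by n_heads
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)  # 1 / sqrt(head_dim)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, attn_mask=None):
        b, t, d = x.shape
        # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
        q = self.q_proj(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:  # additive bias: 0 = keep, -inf = mask
            scores = scores + attn_mask
        out = scores.softmax(dim=-1) @ v
        return self.out_proj(out.transpose(1, 2).reshape(b, t, d))

x = torch.randn(2, 10, 64)
# Causal additive mask, broadcastable over batch and heads.
mask = torch.triu(torch.full((10, 10), float("-inf")), diagonal=1)[None, None]
y = TinyAttention()(x, mask)
print(y.shape)  # torch.Size([2, 10, 64])
```

The output shape matches the input shape, as stated in the forward section above.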
Grouped query attention
GQA reduces KV cache memory by sharing key/value heads across query heads. With n_heads=12 and gqa_groups=4, the 12 query heads share only 4 KV heads, reducing memory by 3x.
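A sketch of the KV-head sharing with the numbers from the paragraph above (repeat_interleave is one common way to expand the shared heads at attention time; shapes here are illustrative):

```python
import torch

batch, n_heads, gqa_groups, seq, head_dim = 2, 12, 4, 10, 16

# K is projected to only gqa_groups heads (the documented
# d_model -> num_kv_heads * head_dim projection), so the KV cache
# stores 4 heads instead of 12: a 3x memory reduction.
k = torch.randn(batch, gqa_groups, seq, head_dim)
q = torch.randn(batch, n_heads, seq, head_dim)

# Each KV head serves n_heads // gqa_groups = 3 query heads.
k_expanded = k.repeat_interleave(n_heads // gqa_groups, dim=1)
print(k_expanded.shape)  # torch.Size([2, 12, 10, 16])
```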
Attention sinks
Attention sinks are learnable tokens prepended to sequences that all tokens can attend to, improving stability for long contexts.

Flash attention
Flash Attention is enabled by default via PyTorch’s scaled_dot_product_attention, providing 2-4x speedup and lower memory usage. It’s automatically used when available and attention sinks are disabled.
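A minimal sketch of the SDPA call (torch.nn.functional.scaled_dot_product_attention, available in PyTorch 2.0+; the shapes are illustrative):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 10, 16)  # (batch, heads, seq, head_dim)
k = torch.randn(2, 8, 10, 16)
v = torch.randn(2, 8, 10, 16)

# PyTorch dispatches to a Flash/memory-efficient kernel when one is
# available for the inputs, falling back to the math implementation.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 10, 16])
```

is_causal=True applies the causal mask inside the kernel, so no explicit additive bias tensor needs to be materialized.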