
AttentionConfig

Configuration dataclass for multi-head attention hyperparameters.

Constructor

from modern_llm.models.attention import AttentionConfig

config = AttentionConfig(
    d_model=768,
    n_heads=12,
    use_rope=True,
    use_gqa=True,
    gqa_groups=4
)
d_model (int, required)
    Model dimension. Must be divisible by n_heads.
n_heads (int, required)
    Number of query heads in multi-head attention.
use_rope (bool, default: True)
    Enable Rotary Position Embeddings (Su et al., 2021) for better position encoding.
rope_theta (float, default: 10000.0)
    Base frequency for RoPE. Higher values give slower position decay.
rope_scaling (Optional[float], default: None)
    Scaling factor for RoPE frequencies to extend context length.
use_attention_sinks (bool, default: False)
    Enable attention sinks (Xiao et al., 2023) for long-context stability.
num_attention_sinks (int, default: 2)
    Number of learnable sink tokens prepended to sequences. Must be > 0 when use_attention_sinks=True.
use_gqa (bool, default: False)
    Enable Grouped Query Attention (Ainslie et al., 2023) to reduce KV cache memory.
gqa_groups (Optional[int], default: None)
    Number of KV head groups. Must divide n_heads when use_gqa=True.
dropout (float, default: 0.0)
    Dropout probability applied to attention weights.
use_flash_attention (bool, default: True)
    Use PyTorch scaled_dot_product_attention (which includes Flash Attention) for a 2-4x speedup.
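
For intuition on rope_theta and rope_scaling, here is a minimal sketch of the standard RoPE frequency schedule; the helper name and the linear-scaling interpretation are illustrative assumptions, not part of the library:

import torch

def rope_frequencies(head_dim, theta=10000.0, scaling=None):
    # One rotation rate per channel pair: theta^(-2i / head_dim).
    # A larger theta yields lower frequencies, i.e. slower position decay.
    freqs = theta ** (-torch.arange(0, head_dim, 2).float() / head_dim)
    if scaling is not None:
        # Linear position interpolation: dividing the frequencies by the
        # scaling factor stretches the effective context length.
        freqs = freqs / scaling
    return freqs

print(rope_frequencies(64).shape)  # torch.Size([32])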

MultiHeadAttention

Multi-head self-attention implementing scaled dot-product attention with modern enhancements. Supports:
  • Rotary Position Embeddings (RoPE) for relative position encoding
  • Grouped Query Attention (GQA) to reduce memory usage
  • Attention sinks for long-context stability
  • Flash Attention via PyTorch SDPA

Mathematical formulation

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k) + mask) V
With RoPE, Q and K are rotated by position-dependent angles before the attention scores are computed.
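
A minimal sketch of this formulation in plain PyTorch (tensor shapes assumed to be (batch, n_heads, seq_len, head_dim); illustrative, not the module's internal code):

import math
import torch
import torch.nn.functional as F

def attention(q, k, v, mask=None):
    # Scaled dot-product scores: (batch, n_heads, seq_len, seq_len).
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores + mask  # additive mask: 0 keeps, -inf removes
    return F.softmax(scores, dim=-1) @ v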

Constructor

from modern_llm.models.attention import MultiHeadAttention, AttentionConfig

config = AttentionConfig(
    d_model=768,
    n_heads=12,
    use_rope=True,
    dropout=0.1
)
attn = MultiHeadAttention(config)
config (AttentionConfig, required)
    Attention configuration containing all hyperparameters.

Attributes

config (AttentionConfig)
    The configuration object passed during initialization.
head_dim (int)
    Dimension per attention head: d_model / n_heads.
num_q_heads (int)
    Number of query heads (equals n_heads from the config).
num_kv_heads (int)
    Number of key/value heads. Equals gqa_groups when using GQA, otherwise n_heads.
scale (float)
    Attention scaling factor: 1 / sqrt(head_dim).
q_proj (nn.Linear)
    Query projection: d_model -> d_model.
k_proj (nn.Linear)
    Key projection: d_model -> (num_kv_heads * head_dim).
v_proj (nn.Linear)
    Value projection: d_model -> (num_kv_heads * head_dim).
out_proj (nn.Linear)
    Output projection: d_model -> d_model.
sink_states (Optional[nn.Parameter])
    Learnable attention sink tokens of shape (num_attention_sinks, d_model) when enabled.
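
A quick way to see how these attributes interact under GQA (the printed values follow directly from the documented definitions):

from modern_llm.models.attention import MultiHeadAttention, AttentionConfig

config = AttentionConfig(d_model=768, n_heads=12, use_gqa=True, gqa_groups=4)
attn = MultiHeadAttention(config)

print(attn.head_dim)      # 64 (768 / 12)
print(attn.num_q_heads)   # 12
print(attn.num_kv_heads)  # 4 (= gqa_groups)
print(attn.k_proj)        # Linear(in_features=768, out_features=256, ...)
                          # since 256 = num_kv_heads * head_dim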

forward

def forward(
    self,
    hidden_states: Tensor,
    attention_mask: Optional[Tensor] = None
) -> Tensor
Compute multi-head self-attention.
hidden_states (Tensor, required)
    Input tensor of shape (batch, seq_len, d_model).
attention_mask (Optional[Tensor], default: None)
    Additive attention bias of shape (batch, 1, seq_len, seq_len). Use 0 for valid positions and large negative values (e.g., -inf) for masked positions.
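
If the model needs causal masking and you build the mask yourself, a sketch in the expected additive form (assuming the module does not add a causal mask internally):

import torch
from modern_llm.models.attention import MultiHeadAttention, AttentionConfig

attn = MultiHeadAttention(AttentionConfig(d_model=768, n_heads=12))
hidden_states = torch.randn(2, 128, 768)

batch, seq_len = hidden_states.shape[:2]
# Future positions get -inf (masked out); allowed positions get 0.
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
mask = causal.view(1, 1, seq_len, seq_len).expand(batch, -1, -1, -1)

output = attn(hidden_states, attention_mask=mask)  # (2, 128, 768)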

Returns

output (Tensor)
    Attention output of shape (batch, seq_len, d_model).

Example

import torch
from modern_llm.models.attention import MultiHeadAttention, AttentionConfig

# Standard multi-head attention
config = AttentionConfig(
    d_model=768,
    n_heads=12,
    use_rope=True,
    dropout=0.1
)
attn = MultiHeadAttention(config)

hidden_states = torch.randn(2, 128, 768)
output = attn(hidden_states)
print(output.shape)  # torch.Size([2, 128, 768])

# Grouped query attention (reduce memory)
gqa_config = AttentionConfig(
    d_model=768,
    n_heads=12,
    use_gqa=True,
    gqa_groups=4  # 12 Q heads share 4 KV heads
)
gqa_attn = MultiHeadAttention(gqa_config)

# With attention sinks for long context
sink_config = AttentionConfig(
    d_model=768,
    n_heads=12,
    use_attention_sinks=True,
    num_attention_sinks=2
)
sink_attn = MultiHeadAttention(sink_config)

Grouped query attention

GQA reduces KV cache memory by sharing key/value heads across query heads. With n_heads=12 and gqa_groups=4, the 12 query heads share only 4 KV heads, reducing memory by 3x.
config = AttentionConfig(
    d_model=768,
    n_heads=12,
    use_gqa=True,
    gqa_groups=4  # 3x memory reduction
)
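
Conceptually, each stored KV head serves n_heads / gqa_groups query heads. A sketch of the expansion step (illustrative; the module's internal implementation may differ):

import torch

batch, seq, n_heads, gqa_groups, head_dim = 2, 128, 12, 4, 64

k = torch.randn(batch, gqa_groups, seq, head_dim)  # only 4 KV heads stored
# Repeat each KV head for the 3 query heads that share it.
k_expanded = k.repeat_interleave(n_heads // gqa_groups, dim=1)
print(k_expanded.shape)  # torch.Size([2, 12, 128, 64])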

Attention sinks

Attention sinks are learnable tokens prepended to sequences that all tokens can attend to, improving stability for long contexts.
config = AttentionConfig(
    d_model=768,
    n_heads=12,
    use_attention_sinks=True,
    num_attention_sinks=2
)
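
Conceptually, the learnable sink_states are broadcast across the batch and prepended to the input before attention; a rough sketch (the exact internal handling, e.g. whether sinks are stripped from the output, is an assumption here):

import torch

batch, seq, d_model, num_sinks = 2, 128, 768, 2
hidden_states = torch.randn(batch, seq, d_model)
sink_states = torch.nn.Parameter(torch.randn(num_sinks, d_model))

# Broadcast the sinks across the batch and prepend them to the sequence.
sinks = sink_states.unsqueeze(0).expand(batch, -1, -1)
extended = torch.cat([sinks, hidden_states], dim=1)
print(extended.shape)  # torch.Size([2, 130, 768])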

Flash attention

Flash Attention is enabled by default via PyTorch's scaled_dot_product_attention, providing a 2-4x speedup and lower memory usage. It is applied automatically when the backend supports it and attention sinks are disabled.
config = AttentionConfig(
    d_model=768,
    n_heads=12,
    use_flash_attention=True  # Default
)
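
For reference, the underlying PyTorch kernel can be called directly; it dispatches to Flash Attention or memory-efficient backends when available:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 12, 128, 64)  # (batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 12, 128, 64])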

Complexity

O(seq_len² · d_model) per layer: each of the n_heads heads forms a seq_len × seq_len score matrix over head_dim = d_model / n_heads channels, so the per-head cost of O(seq_len² · d_model / n_heads) sums to O(seq_len² · d_model) across heads. Doubling the sequence length quadruples the attention cost.
