Overview
The `MHA` class implements multi-head self-attention with several advanced features:
- Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)
- Optional 1D convolution for local context modeling
- Rotary position embeddings (RoPE)
- Integrated MLP for efficiency
- Optimized KV caching for inference
- Flash Attention support for faster computation
Class Definition
Parameters
- `embed_dim`: Embedding dimension of the input.
- `num_heads`: Number of attention heads for queries.
- `num_heads_kv`: Number of key-value heads for Multi-Query Attention (MQA) or Grouped-Query Attention (GQA). If `None`, uses `num_heads` for standard multi-head attention. Must divide `num_heads` evenly.
- `head_dim`: Dimension per attention head. If `None`, uses `embed_dim // num_heads`. Allows for non-standard head dimensions.
- `mlp_dim`: Dimension of the integrated MLP (a gated MLP with SiLU activation). If 0, no MLP is used. The dimension is rounded up to the nearest multiple of 256.
- QKV projection bias: Whether to include bias terms in the QKV projection layer.
- Output projection bias: Whether to include a bias term in the output projection layer.
- Softmax scale: Scale factor for attention scores before softmax. If `None`, uses `1/sqrt(head_dim)` as per the standard Transformer.
- Causal: Whether to use causal (masked) attention. Set to `True` for autoregressive models.
- `layer_idx`: Layer index for KV caching during inference. Required when using inference mode.
- `d_conv`: Kernel size for the 1D causal convolution applied to QKV before attention. If 0, no convolution is used. Adds a local inductive bias.
- `rotary_emb_dim`: Dimension for rotary position embeddings (RoPE). If 0, no rotary embeddings are used. Typically set to `head_dim` or a fraction such as `head_dim // 2`.
- Rotary base: Base value for computing rotary embedding frequencies. Higher values result in slower position-encoding decay.
- Rotary interleaved: Whether to use the interleaved rotary embeddings format. If `False`, uses the standard format.
- `device`: Device to place tensors on (e.g., `torch.device('cuda')`).
- `dtype`: Data type for tensors (e.g., `torch.float16`, `torch.bfloat16`).
Methods
forward
Parameters
- Input tensor of shape `(batch_size, seq_len, embed_dim)`.
- Inference-mode parameters (optional). Should contain:
  - `key_value_memory_dict`: dictionary mapping layer indices to KV caches
  - `seqlen_offset`: current sequence position offset
  - `max_seqlen`: maximum sequence length
  - `lengths_per_sample`: per-sample sequence lengths (optional)
Returns
Output tensor of shape `(batch_size, seq_len, embed_dim)`.
allocate_inference_cache
Parameters
- Batch size for inference.
- Maximum sequence length for inference.
- Data type for cache tensors. If `None`, uses the output projection weight dtype.
Returns
- Tensor of shape `(batch_size, max_seqlen, 2, num_heads_kv, head_dim)` for storing key-value states.
- Tensor of shape `(batch_size, qkv_dim, d_conv)` for convolution state, or `None` if `d_conv=0`.
Usage Examples
Basic Multi-Head Attention
Multi-Query Attention (MQA)
Grouped-Query Attention (GQA)
With Rotary Embeddings
With Local Convolution
With Integrated MLP
Inference with KV Caching
Architecture Details
Multi-Query and Grouped-Query Attention
- Standard MHA: Each head has its own Q, K, V (memory intensive)
- MQA (`num_heads_kv=1`): All query heads share one K, V pair (memory efficient)
- GQA (`num_heads_kv=k`): Query heads are grouped, each group shares K, V (balanced trade-off)
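The grouping can be made concrete with a few lines of plain Python (an illustrative sketch, not code from the module):

```python
def kv_head_for(q_head: int, num_heads: int, num_heads_kv: int) -> int:
    """Map a query head index to the KV head it reads from.

    Query heads are split into num_heads_kv contiguous groups;
    each group shares one K/V pair.
    """
    assert num_heads % num_heads_kv == 0, "num_heads_kv must divide num_heads"
    group_size = num_heads // num_heads_kv
    return q_head // group_size

# Standard MHA: 8 query heads, 8 KV heads -> identity mapping.
print([kv_head_for(h, 8, 8) for h in range(8)])  # [0, 1, 2, 3, 4, 5, 6, 7]
# GQA: 8 query heads, 2 KV heads -> groups of 4.
print([kv_head_for(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
# MQA: 8 query heads, 1 KV head -> everyone shares head 0.
print([kv_head_for(h, 8, 1) for h in range(8)])  # [0, 0, 0, 0, 0, 0, 0, 0]
```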
Integrated MLP
When `mlp_dim > 0`, the module includes a gated MLP with SiLU activation.
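A scalar sketch of the gating structure (the module's exact weight layout is not specified here; this assumes the common gate/up/down arrangement and just shows the SiLU gating):

```python
import math

def silu(x: float) -> float:
    """SiLU (swish) activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def gated_mlp(x: float, w_gate: float, w_up: float, w_down: float) -> float:
    """Scalar sketch of a SiLU-gated MLP: down(silu(gate(x)) * up(x)).

    Real weights are matrices of width mlp_dim; scalars keep the
    structure visible.
    """
    return w_down * (silu(w_gate * x) * (w_up * x))
```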
Rotary Position Embeddings
RoPE encodes position information by rotating query and key representations:
- Applied only to the first `rotary_emb_dim` dimensions
- Allows extrapolation to longer sequences than seen during training
- No learned parameters required
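The rotation can be sketched per channel pair (standard non-interleaved pairing; the frequency formula and the base value 10000 are the conventional RoPE choices, assumed rather than confirmed by this document):

```python
import math

def rope_pair(x1: float, x2: float, pos: int, pair_idx: int,
              rotary_dim: int, base: float = 10000.0):
    """Rotate one (x1, x2) channel pair by a position-dependent angle.

    theta = pos * base ** (-2 * pair_idx / rotary_dim)
    """
    theta = pos * base ** (-2.0 * pair_idx / rotary_dim)
    c, s = math.cos(theta), math.sin(theta)
    return x1 * c - x2 * s, x1 * s + x2 * c

# Position 0 is the identity rotation; rotation never changes vector norm,
# which is why no learned parameters are needed.
print(rope_pair(1.0, 0.0, 0, 0, 64))  # (1.0, 0.0)
```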
Performance Considerations
- Flash Attention: Automatically used when available for faster computation and lower memory usage
- KV Caching: Essential for efficient autoregressive generation
- MQA/GQA: Reduces KV cache size and memory bandwidth requirements
- Fused Operations: Convolution and rotary embeddings can be fused with attention computation
Notes
- Requires the `flash_attn` package for optimal performance
- When using rotary embeddings, `flash_attn` is required
- Causal convolution requires the `causal_conv1d` package for best performance
- The `layer_idx` parameter must be set when using inference mode with caching
