Overview
Attention sinks are learnable tokens prepended to the input sequence that every token can attend to. They serve as stable attention targets, improving model performance on long sequences and enabling streaming generation with a fixed-size KV cache.
The problem: Attention collapse
In standard transformer decoders, long autoregressive generation can suffer from **attention collapse**:
During generation, the model must attend to previous tokens. For very long sequences:
- **Early tokens lose relevance**: the first few tokens become semantically irrelevant to the current generation
- **Attention must sum to 1**: softmax forces the attention weights to sum to 1.0
- **Nowhere to put attention**: the model has no good place to "dump" unused attention mass
- **Attention collapses**: all attention concentrates on a few recent tokens, losing access to earlier context
Example:

```
Sequence: "Once upon a time, in a land far away... [4000 more tokens]"

When generating token 4096:
- "Once" is no longer relevant
- "upon" is no longer relevant
- But attention weights must still sum to 1.0!
- Where should the model put this "unused" attention?
```
Without attention sinks:

- **Instability**: attention weights become unstable
- **Poor generation**: the model loses coherence after 1K-2K tokens
- **Cache size**: the entire KV cache must be kept in memory
- **No streaming**: infinite sequences cannot be generated
Empirical observation (Xiao et al., 2023), perplexity vs. sequence length without sinks:

```
Length 1K:  15.2 PPL  ✓
Length 2K:  18.7 PPL  ⚠️
Length 4K:  45.3 PPL  ❌
Length 8K:  diverges  ❌❌
```
Attention sinks provide designated attention targets:

- **Attention dumping ground**: the model can attend to the sinks when the current context is sufficient
- **Stable across the sequence**: sink tokens don't change as the sequence grows
- **Learnable**: the model learns during training what patterns to store in the sinks
- **Always accessible**: sinks stay in the KV cache even when old tokens are evicted
With attention sinks:

```
Sequence: "[SINK1] [SINK2] Once upon a time, in a land far away... [4000 more tokens]"

When generating token 4096:
- Can attend to SINK1 or SINK2 instead of irrelevant early tokens
- Attention distribution remains stable
- Model maintains coherence
```
Standard attention
Without sinks, attention operates on sequence tokens only:
```
Q, K, V = project(hidden_states)   # shape: (batch, n_heads, seq_len, head_dim)
scores  = Q @ K^T / sqrt(d_k)      # shape: (batch, n_heads, seq_len, seq_len)
attn    = softmax(scores)          # each row sums to 1.0
output  = attn @ V
```
With attention sinks
Prepend learnable sink tokens to K and V:
```
# Learnable sink embeddings
S = [s₁, s₂, ..., sₙ]                      # n sink tokens, shape: (n_sinks, d_model)

# Project sinks to K, V
K_sinks = project_K(S)                     # shape: (batch, n_heads, n_sinks, head_dim)
V_sinks = project_V(S)                     # shape: (batch, n_heads, n_sinks, head_dim)

# Concatenate with sequence K, V
K_full = concat([K_sinks, K_seq], dim=2)   # shape: (batch, n_heads, n_sinks + seq_len, head_dim)
V_full = concat([V_sinks, V_seq], dim=2)

# Attention can now go to sinks or to the sequence
scores = Q @ K_full^T / sqrt(d_k)          # shape: (batch, n_heads, seq_len, n_sinks + seq_len)
attn   = softmax(scores)                   # attention can be distributed to the sinks!
output = attn @ V_full
```
Key insight: By expanding the attention matrix from seq_len × seq_len to seq_len × (n_sinks + seq_len), we give the model more “degrees of freedom” for where to place attention.
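The expanded softmax can be sketched end to end with NumPy. This is a minimal, single-head sketch with random weights and no causal mask; `K_sink` and `V_sink` stand in for the projected learnable sink embeddings:

```python
import numpy as np

rng = np.random.default_rng(0)
n_sinks, seq_len, d = 2, 6, 8

# Queries/keys/values for the sequence, plus stand-in sink keys/values.
Q = rng.standard_normal((seq_len, d))
K_seq = rng.standard_normal((seq_len, d))
V_seq = rng.standard_normal((seq_len, d))
K_sink = rng.standard_normal((n_sinks, d))
V_sink = rng.standard_normal((n_sinks, d))

# Concatenate the sinks in front of the sequence keys/values.
K = np.concatenate([K_sink, K_seq], axis=0)   # (n_sinks + seq_len, d)
V = np.concatenate([V_sink, V_seq], axis=0)

scores = Q @ K.T / np.sqrt(d)                 # (seq_len, n_sinks + seq_len)
attn = np.exp(scores - scores.max(axis=-1, keepdims=True))
attn /= attn.sum(axis=-1, keepdims=True)      # softmax: each row still sums to 1

out = attn @ V                                # (seq_len, d)

# Every query places some (nonzero) attention mass on the sinks.
sink_mass = attn[:, :n_sinks].sum(axis=-1)
```

The row sums are still 1.0; the difference is that each query now has `n_sinks` extra key columns to absorb attention mass that would otherwise be forced onto sequence tokens.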
Implementation
This implementation adds attention sinks with careful position encoding:
Sink initialization
```python
# During model initialization
if config.use_attention_sinks:
    # Learnable sink embeddings
    self.sink_states = nn.Parameter(
        torch.randn(config.num_attention_sinks, config.d_model) * 0.02
    )
else:
    self.register_parameter("sink_states", None)
```
Design choices:

- Initialized with small random noise (std = 0.02)
- Registered as parameters, so they are learned during training
- Same dimension as the hidden states (d_model)
- Typical count: 2-4 sink tokens
Forward pass

```python
def forward(self, hidden_states: Tensor, attention_mask: Optional[Tensor] = None) -> Tensor:
    batch_size, seq_len, _ = hidden_states.shape

    # Project the sequence to Q, K, V
    q = self._shape_q(self.q_proj(hidden_states))
    k = self._shape_kv(self.k_proj(hidden_states))
    v = self._shape_kv(self.v_proj(hidden_states))

    use_sinks = self.config.use_attention_sinks and self.sink_states is not None
    n_sinks = self.config.num_attention_sinks if use_sinks else 0

    # Apply RoPE: queries and sequence keys share the same offset,
    # so sequence positions start after the sinks
    if self.config.use_rope:
        q = self._apply_rope(q, seq_len, offset=n_sinks)
        k = self._apply_rope(k, seq_len, offset=n_sinks)

    key_length = seq_len

    # Add attention sinks
    if use_sinks:
        # Expand the sinks for the batch
        sink_states = self.sink_states.unsqueeze(0).expand(batch_size, -1, -1)

        # Project the sinks to K, V
        sink_k = self._shape_kv(self.k_proj(sink_states))
        sink_v = self._shape_kv(self.v_proj(sink_states))

        # Apply RoPE to the sink keys (positions 0 .. n_sinks - 1)
        if self.config.use_rope:
            sink_k = self._apply_rope(sink_k, n_sinks, offset=0)

        # Concatenate the sinks in front of the sequence K, V
        k = torch.cat([sink_k, k], dim=2)
        v = torch.cat([sink_v, v], dim=2)
        key_length += n_sinks

        # Extend the attention mask so every query may attend to the sinks
        if attention_mask is not None:
            sink_bias = torch.zeros(
                batch_size, 1,
                attention_mask.size(-2),
                n_sinks,
                device=attention_mask.device,
                dtype=attention_mask.dtype,
            )
            attention_mask = torch.cat([sink_bias, attention_mask], dim=-1)

    # Standard attention computation
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
    if attention_mask is not None:
        attn_scores = attn_scores + attention_mask
    attn_probs = torch.softmax(attn_scores, dim=-1)
    context = torch.matmul(attn_probs, v)

    return self.out_proj(
        context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
    )
```
RoPE integration

Careful position encoding is critical:

```python
# Sink positions: 0, 1, ..., n_sinks - 1
sink_k = self._apply_rope(sink_k, n_sinks, offset=0)

# Sequence positions: n_sinks, n_sinks + 1, ..., n_sinks + seq_len - 1
q = self._apply_rope(q, seq_len, offset=n_sinks)
k = self._apply_rope(k, seq_len, offset=n_sinks)  # keys use the same offset as queries
```
Why this works:

- Sink keys get positions 0, 1, ...
- Sequence keys get positions starting after the sinks
- Queries use the same offset as the sequence keys, so every query sees the correct relative distance to every key
- RoPE's relative encoding handles the rest automatically
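The offset rule can be sanity-checked numerically. The sketch below uses a minimal, illustrative `rope` helper (not the module's `_apply_rope`) to confirm that shifting query and key positions by the same sink offset leaves the attention score unchanged:

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Apply a rotary embedding to vector x (last dim = d) at integer position pos."""
    d = x.shape[-1]
    inv_freq = base ** (-np.arange(0, d, 2) / d)      # (d/2,) frequencies
    angles = np.asarray(pos)[..., None] * inv_freq    # rotation angle per pair
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin              # rotate each 2-D pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)

n_sinks = 2
# Query at sequence index 10, key at sequence index 3, both shifted by n_sinks:
s_offset = rope(q, 10 + n_sinks) @ rope(k, 3 + n_sinks)
# Same relative distance (7) without the shift:
s_plain = rope(q, 10) @ rope(k, 3)
```

Because RoPE rotates each 2-D pair by an angle proportional to the position, the dot product depends only on the position difference, so a common offset cancels out.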
Streaming generation
Attention sinks enable infinite sequence generation with bounded memory:
With attention sinks, you can evict old tokens while keeping the sinks:

```python
class StreamingCache:
    def __init__(self, max_length=2048, num_sinks=2):
        self.max_length = max_length
        self.num_sinks = num_sinks
        self.k_cache = None
        self.v_cache = None

    def update(self, new_k, new_v):
        if self.k_cache is None:
            # First tokens: includes the sinks
            self.k_cache = new_k
            self.v_cache = new_v
        else:
            # Append the new tokens
            self.k_cache = torch.cat([self.k_cache, new_k], dim=2)
            self.v_cache = torch.cat([self.v_cache, new_v], dim=2)

        # If exceeding max length, evict old tokens (but keep the sinks!)
        if self.k_cache.size(2) > self.max_length:
            # Keep: [sinks] + [recent tokens]
            keep_length = self.max_length - self.num_sinks
            self.k_cache = torch.cat([
                self.k_cache[:, :, :self.num_sinks],   # keep sinks
                self.k_cache[:, :, -keep_length:],     # keep recent
            ], dim=2)
            self.v_cache = torch.cat([
                self.v_cache[:, :, :self.num_sinks],
                self.v_cache[:, :, -keep_length:],
            ], dim=2)

        return self.k_cache, self.v_cache
```
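The eviction rule can be exercised on a toy cache. This sketch flattens the cache to (seq, head_dim) for readability; the slicing logic is the same as the `dim=2` version above:

```python
import numpy as np

max_length, num_sinks, head_dim = 8, 2, 4

# A full cache: rows 0..1 are the sinks, the rest are sequence tokens.
cache = np.arange(max_length * head_dim, dtype=float).reshape(max_length, head_dim)
sinks = cache[:num_sinks].copy()

def evict(cache, new_block):
    """Append new_block, then drop the oldest non-sink rows past max_length."""
    cache = np.concatenate([cache, new_block], axis=0)
    if cache.shape[0] > max_length:
        keep = max_length - num_sinks
        cache = np.concatenate([cache[:num_sinks], cache[-keep:]], axis=0)
    return cache

for step in range(5):  # stream 5 more tokens
    cache = evict(cache, np.full((1, head_dim), 100.0 + step))
```

After any number of steps the cache length stays at `max_length` and the first `num_sinks` rows are untouched, which is exactly the invariant streaming generation relies on.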
Memory usage: fixed at 2 × max_length × n_heads × head_dim floats per layer (K and V), regardless of generation length!

```python
def generate_streaming(model, prompt_ids, max_new_tokens=10000):
    cache = StreamingCache(max_length=2048, num_sinks=2)
    current_ids = prompt_ids
    generated = []

    for _ in range(max_new_tokens):
        # Forward pass with the KV cache
        outputs = model(
            current_ids,
            past_key_values=cache,
            use_cache=True,
        )

        # Sample the next token
        next_token = sample(outputs.logits[:, -1, :])
        generated.append(next_token)

        # Only the new token is fed in the next iteration
        current_ids = next_token.unsqueeze(0)

        # The cache automatically evicts old tokens and keeps the sinks,
        # so memory usage remains constant
    return torch.cat(generated)
```
Key benefit: arbitrary generation length with bounded memory (at most 2048 cached positions in this example). From Xiao et al. (2023):

| Cache size | Without sinks | With 4 sinks | Full cache |
|------------|---------------|--------------|------------|
| 1K tokens  | 45.3 PPL ❌   | 15.8 PPL ✓   | 15.2 PPL   |
| 2K tokens  | 28.1 PPL ⚠️   | 15.5 PPL ✓   | 15.2 PPL   |
| 4K tokens  | 18.2 PPL ⚠️   | 15.3 PPL ✓   | 15.2 PPL   |
With sinks, streaming cache matches full cache performance!
Configuration
Number of sinks
Typical values: 2-4 sink tokens
```python
config = ModernLLMConfig(
    use_attention_sinks=True,
    num_attention_sinks=2,  # default
)
```
| Num sinks | Parameters added | Use case |
|-----------|------------------|----------|
| 1 | d_model | Minimal (may be insufficient) |
| 2 | 2 × d_model | Standard (recommended) |
| 4 | 4 × d_model | High capacity |
| 8 | 8 × d_model | Research/debugging |
Empirical results (Xiao et al., 2023):
1 sink: Moderate improvement
2 sinks: Significant improvement (recommended)
4 sinks: Slightly better than 2
8 sinks: Marginal gains over 4
Recommendation : Start with 2, increase to 4 if working with very long contexts (>8K tokens).
For a model with d_model=768 and n_layers=12:

```
Parameters per sink per layer:
- Sink embedding: 768
- No additional weights (sinks reuse the existing K/V projections)

Total for 2 sinks across 12 layers:
2 × 768 × 12 = 18,432 parameters (negligible!)
```

Memory during the forward pass:

```
KV cache increase per layer:
2 sinks × n_heads × head_dim × 2 (K and V)
= 2 × 12 × 64 × 2 = 3,072 floats per layer
= 12.3 KB per layer (fp32)
= 147.5 KB for 12 layers (negligible!)
```

Attention sinks add minimal overhead!
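A quick sanity check of this arithmetic, assuming each of the 12 layers owns its own pair of sink embeddings:

```python
# Back-of-envelope overhead for d_model=768, 12 layers, 12 heads, head_dim=64,
# with 2 sinks per layer.
d_model, n_layers, n_heads, head_dim, n_sinks = 768, 12, 12, 64, 2

# Extra parameters: just the sink embeddings (projections are reused).
params = n_sinks * d_model * n_layers

# Extra KV cache: n_sinks entries of K and of V per head, per layer, in fp32.
kv_floats_per_layer = n_sinks * n_heads * head_dim * 2   # K and V
kv_bytes = kv_floats_per_layer * 4 * n_layers            # 4 bytes per fp32

print(params)     # 18432 parameters
print(kv_bytes)   # 147456 bytes, i.e. ~147.5 kB across 12 layers
```

Both totals are dwarfed by the model itself (a 12-layer, d_model=768 transformer has on the order of 10⁸ parameters), which is why the overhead is described as negligible.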
Training considerations
When training a new model with attention sinks:

```python
config = ModernLLMConfig(
    use_attention_sinks=True,
    num_attention_sinks=2,
    # Standard training settings
    max_seq_len=2048,
    # ...
)

model = ModernDecoderLM(config)
# Train normally; the sinks learn automatically!
```
What sinks learn:

- Where to dump unused attention
- Global context patterns (e.g., sentence boundaries)
- Task-specific information (e.g., length constraints for summarization)
Adding sinks to a pretrained model:

```python
# Load a pretrained model without sinks
base_model = ModernDecoderLM.from_pretrained("path/to/model")

# Create a new config with sinks
new_config = base_model.config
new_config.use_attention_sinks = True
new_config.num_attention_sinks = 2

# Initialize the new model
model_with_sinks = ModernDecoderLM(new_config)

# Copy the weights (sink parameters stay randomly initialized)
model_with_sinks.load_state_dict(
    base_model.state_dict(),
    strict=False,  # allow missing sink parameters
)

# Fine-tune briefly so the model learns to use the sinks
# Recommended: ~1-5% of the original pretraining steps
```
Fine-tuning tips:

- Use a lower learning rate (about 10× smaller than pretraining)
- Train on long sequences (at least 2× your target length)
- Monitor attention patterns to the sinks
Gradually increase sequence length during training:

```python
# Phase 1: short sequences (sinks learn basic patterns)
train(max_seq_len=512, steps=10000)

# Phase 2: medium sequences (sinks learn to handle context)
train(max_seq_len=1024, steps=10000)

# Phase 3: long sequences (sinks learn eviction behavior)
train(max_seq_len=2048, steps=10000)

# Phase 4: very long (optional, for streaming)
train(max_seq_len=4096, steps=5000)
```
This helps sinks learn progressively more sophisticated patterns.
Empirical results
Long-context generation
From Xiao et al. (2023) - Streaming with Attention Sinks:
| Model | Context | Without sinks | With 4 sinks |
|-------|---------|---------------|--------------|
| LLaMA-7B  | 4K  | 45.3 PPL | 15.8 PPL |
| LLaMA-7B  | 8K  | diverges | 16.2 PPL |
| LLaMA-7B  | 16K | diverges | 17.1 PPL |
| LLaMA-13B | 4K  | 38.7 PPL | 14.2 PPL |
| LLaMA-13B | 8K  | diverges | 14.9 PPL |
Attention sinks enable up to 10× longer context than training length with stable performance.
Memory efficiency
| Generation length | Memory (no sinks) | Memory (with sinks) | Reduction |
|-------------------|-------------------|---------------------|-----------|
| 4K tokens  | 4K cache  | 2K cache | 50% |
| 8K tokens  | 8K cache  | 2K cache | 75% |
| 16K tokens | 16K cache | 2K cache | 87.5% |
| 32K tokens | 32K cache | 2K cache | 93.75% |
Breakthrough : Memory usage is constant regardless of generation length!
Comparison to alternatives
| Method | Memory | Quality | Implementation |
|--------|--------|---------|----------------|
| Full attention | O(n²) | Best | Simple |
| Sliding window | O(w) | Poor (loses old context) | Simple |
| Sparse attention | O(n√n) | Moderate | Complex |
| Attention sinks | O(w) | Excellent | Moderate |
| Recurrent memory | O(m) | Good | Complex |
Where:
n = sequence length
w = window size
m = memory size
Attention sinks provide the best trade-off between memory efficiency, generation quality, and implementation complexity.
Advanced topics
Allocate different numbers of sinks per layer:

```python
# Early layers: fewer sinks (local patterns)
# Later layers: more sinks (global reasoning)
def get_num_sinks(layer_idx, n_layers):
    if layer_idx < n_layers // 3:
        return 1  # early: 1 sink
    elif layer_idx < 2 * n_layers // 3:
        return 2  # middle: 2 sinks
    else:
        return 4  # late: 4 sinks
```
Research shows this can slightly improve efficiency with minimal quality loss.
Use different sinks for different context levels:

```python
# Sink 0: token-level (recent few tokens)
# Sink 1: sentence-level (current sentence)
# Sink 2: paragraph-level (current paragraph)
# Sink 3: document-level (overall topic)
```

Train the sinks with auxiliary losses to enforce this hierarchy.
Fine-tune sinks for specific tasks:

- **Summarization**: a sink learns to store the desired summary length and the key topics to include
- **Question answering**: a sink stores the question embedding and the expected answer type
- **Code generation**: a sink stores language/framework context and the current function signature
Attention pattern analysis
Visualize what models learn to store in the sinks:

```python
# Compute average attention to each sink
attn_to_sinks = attn_weights[:, :, :, :num_sinks].mean(dim=(0, 1, 2))
print(f"Attention to sink 0: {attn_to_sinks[0]:.3f}")
print(f"Attention to sink 1: {attn_to_sinks[1]:.3f}")

# Plot an attention heatmap
plt.imshow(attn_weights[0, 0, :, :num_sinks])
plt.xlabel("Sink token")
plt.ylabel("Query position")
```
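A self-contained toy version of this analysis, using randomly generated attention weights in NumPy (no trained model or plotting required):

```python
import numpy as np

rng = np.random.default_rng(1)
batch, n_heads, seq_len, num_sinks = 2, 4, 16, 2

# Fake attention weights over [sinks | sequence], normalized along the key axis.
logits = rng.standard_normal((batch, n_heads, seq_len, num_sinks + seq_len))
attn_weights = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

# Average attention mass each sink receives, over batch, heads, and query positions.
attn_to_sinks = attn_weights[..., :num_sinks].mean(axis=(0, 1, 2))
for i, mass in enumerate(attn_to_sinks):
    print(f"Attention to sink {i}: {mass:.3f}")
```

On a trained model, the same reduction over real attention weights reveals how much mass each sink absorbs; with random weights it simply averages to roughly 1/(num_sinks + seq_len) per key.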
Common patterns:

- First sink: receives a stable ~5-10% of attention
- Second sink: receives variable, task-dependent attention
- Sequence tokens: less attention mass collapses onto the first sequence token
Common issues
Sinks not being used

Symptom: attention to the sinks is near zero.

Causes:

- Training sequences too short (the model doesn't need the sinks)
- Learning rate too low (the sinks don't learn)
- Initialization too small (the sinks are ignored)

Solutions:

```python
# Train on longer sequences
config.max_seq_len = 4096  # up from 1024

# Slightly larger initialization
self.sink_states = nn.Parameter(
    torch.randn(n_sinks, d_model) * 0.05  # up from 0.02
)

# Separate, higher learning rate for the sink parameters
optimizer = AdamW([
    {'params': model.base_params, 'lr': 3e-4},
    {'params': model.sink_params, 'lr': 1e-3},  # higher!
])
```
Position encoding conflicts
Issue: RoPE positions for the sequence would collide with the sink positions.

Solution: apply consistent position offsets (already implemented):

```python
# Sinks get positions 0, 1, ..., n_sinks - 1
sink_k = apply_rope(sink_k, n_sinks, offset=0)

# Sequence keys continue from n_sinks
k = apply_rope(k, seq_len, offset=n_sinks)

# Queries use the same offset, so relative positions stay correct
q = apply_rope(q, seq_len, offset=n_sinks)
```
Memory leaks in streaming
Issue: the KV cache grows despite the sinks.

Cause: old tokens are not being evicted correctly.

Solution: always keep the sinks at the start of the cache:

```python
# WRONG: eviction drops the sinks too
cache = cache[:, :, -max_length:]

# CORRECT: keep sinks + recent tokens
cache = torch.cat([
    cache[:, :, :num_sinks],                  # keep sinks
    cache[:, :, -(max_length - num_sinks):],  # keep recent
], dim=2)
```
References
- Press et al. (2021), "Train Short, Test Long": motivation for attention sinks
- Xiao et al. (2023), "Efficient Streaming Language Models with Attention Sinks": explicit attention-sink implementation and analysis
- Han et al. (2023), "LM-Infinite: Simple On-the-Fly Length Generalization": a related approach to unbounded generation
- Mohtashami & Jaggi (2023), "Landmark Attention": a similar concept using "landmark" tokens
See also
Architecture overview Learn about the full model architecture
RoPE Position encoding used with attention sinks
Multi-head attention Full attention implementation details
Configuration Configure attention sink parameters