
Overview

Attention sinks are learnable tokens prepended to the input sequence that every token can attend to. They serve as stable attention targets, improving model performance on long sequences and enabling streaming generation with fixed-size KV cache.
Motivation:
  • Press et al. (2021) - Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation
  • Xiao et al. (2023) - Efficient Streaming Language Models with Attention Sinks
Attention sinks help prevent attention collapse in autoregressive generation and enable infinite sequence generation with bounded memory.

The problem: Attention collapse

In standard transformer decoders, long autoregressive generation can suffer from attention collapse.
During generation, the model must attend to previous tokens. For very long sequences:
  1. Early tokens lose relevance: The first few tokens become semantically irrelevant to the current generation
  2. Attention must sum to 1: Softmax forces attention weights to sum to 1.0
  3. Nowhere to put attention: Model has no good place to “dump” unused attention mass
  4. Attention collapses: All attention concentrates on a few recent tokens, losing access to earlier context
Example:
Sequence: "Once upon a time, in a land far away... [4000 more tokens]"

When generating token 4096:
- "Once" is no longer relevant
- "upon" is no longer relevant  
- But attention weights must sum to 1.0!
- Where should the model put this "unused" attention?
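The constraint in step 2 is easy to see numerically. Below is a minimal NumPy sketch with illustrative scores (not taken from a real model): softmax always allocates the full attention budget, and a single extra "sink" slot can absorb most of it when the model assigns it a high score.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())  # subtract max for numerical stability
    return e / e.sum()

# Scores for 4 context tokens, all weakly relevant (hypothetical values)
scores = np.array([0.1, 0.2, 0.1, 0.15])
attn = softmax(scores)
print(attn.sum())  # ≈ 1.0: the mass must go somewhere

# Add one sink slot; the sink score here is illustrative of what training learns
scores_with_sink = np.concatenate([[2.0], scores])
attn_with_sink = softmax(scores_with_sink)
print(attn_with_sink[0])  # ≈ 0.62: most of the mass lands on the sink
```

Without the sink, the 4 near-irrelevant tokens are forced to share all the attention; with it, they receive only the leftover mass.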

Mathematical formulation

Standard attention

Without sinks, attention operates on sequence tokens only:
Q, K, V = project(hidden_states)  # Shape: (batch, n_heads, seq_len, head_dim)
scores = QK^T / sqrt(d_k)          # Shape: (batch, n_heads, seq_len, seq_len)
attn = softmax(scores)              # Row sums to 1.0
output = attn V

With attention sinks

Prepend learnable sink tokens to K and V:
# Learnable sink embeddings
S = [s₁, s₂, ..., sₙ]  # n sink tokens, shape: (n_sinks, d_model)

# Project sinks to K, V
K_sinks = project_K(S)  # Shape: (batch, n_heads, n_sinks, head_dim)
V_sinks = project_V(S)  # Shape: (batch, n_heads, n_sinks, head_dim)

# Concatenate with sequence K, V
K_full = concat([K_sinks, K_seq], dim=2)  # Shape: (batch, n_heads, n_sinks + seq_len, head_dim)
V_full = concat([V_sinks, V_seq], dim=2)

# Attention can now attend to sinks or sequence
scores = Q K_full^T / sqrt(d_k)  # Shape: (batch, n_heads, seq_len, n_sinks + seq_len)
attn = softmax(scores)            # Can distribute attention to sinks!
output = attn V_full
Key insight: By expanding the attention matrix from seq_len × seq_len to seq_len × (n_sinks + seq_len), we give the model more “degrees of freedom” for where to place attention.
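The formulation above can be run end to end. Here is a self-contained PyTorch sketch with toy sizes and random tensors standing in for real projections (the causal mask is omitted for brevity):

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes for illustration
batch, n_heads, seq_len, head_dim, n_sinks = 2, 4, 16, 8, 2

q = torch.randn(batch, n_heads, seq_len, head_dim)
k_seq = torch.randn(batch, n_heads, seq_len, head_dim)
v_seq = torch.randn(batch, n_heads, seq_len, head_dim)

# Sink K/V (already projected), broadcast across the batch
k_sinks = torch.randn(1, n_heads, n_sinks, head_dim).expand(batch, -1, -1, -1)
v_sinks = torch.randn(1, n_heads, n_sinks, head_dim).expand(batch, -1, -1, -1)

# Prepend sinks along the key/value sequence axis
k_full = torch.cat([k_sinks, k_seq], dim=2)
v_full = torch.cat([v_sinks, v_seq], dim=2)

scores = q @ k_full.transpose(-2, -1) / head_dim ** 0.5
attn = F.softmax(scores, dim=-1)  # each row sums to 1 over n_sinks + seq_len keys
out = attn @ v_full

print(scores.shape)  # (2, 4, 16, 18): queries attend to 2 sinks + 16 tokens
print(out.shape)     # (2, 4, 16, 8): output shape is unchanged
```

Note that the output keeps the original sequence length; sinks appear only on the key/value side of the attention matrix.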

Implementation

Modern LLM implements attention sinks with careful position encoding:
attention.py:95-100
# During model initialization
if config.use_attention_sinks:
    # Learnable sink embeddings
    self.sink_states = nn.Parameter(
        torch.randn(config.num_attention_sinks, config.d_model) * 0.02
    )
else:
    self.register_parameter("sink_states", None)
Design choices:
  • Initialized with small random noise (std=0.02)
  • Registered as parameters (learned during training)
  • Same dimension as hidden states (d_model)
  • Typical count: 2-4 sink tokens

Streaming generation

Attention sinks enable infinite sequence generation with bounded memory:
With attention sinks, you can evict old tokens while keeping sinks:
import torch

class StreamingCache:
    def __init__(self, max_length=2048, num_sinks=2):
        self.max_length = max_length
        self.num_sinks = num_sinks
        self.k_cache = None
        self.v_cache = None
    
    def update(self, new_k, new_v):
        if self.k_cache is None:
            # First tokens: includes sinks
            self.k_cache = new_k
            self.v_cache = new_v
        else:
            # Append new tokens
            self.k_cache = torch.cat([self.k_cache, new_k], dim=2)
            self.v_cache = torch.cat([self.v_cache, new_v], dim=2)
            
            # If exceeding max length, evict old tokens (but keep sinks!)
            if self.k_cache.size(2) > self.max_length:
                # Keep: [sinks] + [recent tokens]
                keep_length = self.max_length - self.num_sinks
                self.k_cache = torch.cat([
                    self.k_cache[:, :, :self.num_sinks],  # Keep sinks
                    self.k_cache[:, :, -keep_length:]     # Keep recent
                ], dim=2)
                self.v_cache = torch.cat([
                    self.v_cache[:, :, :self.num_sinks],
                    self.v_cache[:, :, -keep_length:]
                ], dim=2)
        
        return self.k_cache, self.v_cache
Memory usage: Fixed at 2 (K and V) × max_length × n_heads × head_dim floats per layer, regardless of generation length!
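The eviction step can be checked in isolation. In this standalone sketch, a toy cache encodes each entry's original position in its value, so the slicing (the same indexing used in update above) is easy to verify:

```python
import torch

max_length, num_sinks = 8, 2
# Toy cache of 12 positions; each value encodes its position index (0..11)
k_cache = torch.arange(12).float().view(1, 1, 12, 1)

# Evict down to max_length: keep the sinks plus the most recent tokens
keep_length = max_length - num_sinks
k_cache = torch.cat([
    k_cache[:, :, :num_sinks],    # sink positions 0, 1
    k_cache[:, :, -keep_length:], # recent positions 6..11
], dim=2)

print(k_cache.flatten().tolist())  # [0.0, 1.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]
```

Positions 2-5 are evicted; the sinks survive every eviction, which is exactly what keeps attention stable.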

Configuration

Number of sinks

Typical values: 2-4 sink tokens
config = ModernLLMConfig(
    use_attention_sinks=True,
    num_attention_sinks=2,  # Default
)
| Num sinks | Parameters added | Use case |
|---|---|---|
| 1 | d_model | Minimal (may be insufficient) |
| 2 | 2 × d_model | Standard (recommended) |
| 4 | 4 × d_model | High capacity |
| 8 | 8 × d_model | Research/debugging |
Empirical results (Xiao et al., 2023):
  • 1 sink: Moderate improvement
  • 2 sinks: Significant improvement (recommended)
  • 4 sinks: Slightly better than 2
  • 8 sinks: Marginal gains over 4
Recommendation: Start with 2, increase to 4 if working with very long contexts (>8K tokens).
For a model with d_model=768 and n_layers=12:
Parameters per sink (shared across layers):
- Sink embedding: 768
- No additional weights (sinks reuse each layer's existing K/V projections)

Total for 2 sinks:
2 × 768 = 1,536 parameters (negligible!)
Memory during forward pass:
KV cache increase per layer:
n_sinks × n_heads × head_dim × 2 (K and V)
= 2 × 12 × 64 × 2 = 3,072 floats per layer
= 12.3 KB per layer (fp32)
= 147.5 KB for 12 layers (negligible!)
Attention sinks add minimal overhead!
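The memory arithmetic above can be reproduced directly (assuming, as in the text, n_sinks=2, n_heads=12, head_dim=64, 12 layers, fp32 at 4 bytes per float):

```python
# Per-layer KV cache overhead of the sink tokens
n_sinks, n_heads, head_dim, n_layers, bytes_per_float = 2, 12, 64, 12, 4

floats_per_layer = n_sinks * n_heads * head_dim * 2  # × 2 for K and V
kb_per_layer = floats_per_layer * bytes_per_float / 1000
total_kb = kb_per_layer * n_layers

print(floats_per_layer)       # 3072
print(round(kb_per_layer, 1)) # 12.3
print(round(total_kb, 1))     # 147.5
```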

Training considerations

When training a new model with attention sinks:
config = ModernLLMConfig(
    use_attention_sinks=True,
    num_attention_sinks=2,
    
    # Standard training settings
    max_seq_len=2048,
    # ...
)

model = ModernDecoderLM(config)
# Train normally - sinks learn automatically!
What sinks learn:
  • Where to dump unused attention
  • Global context patterns (e.g., sentence boundaries)
  • Task-specific information (e.g., summarization: store length constraints)

Empirical results

Long-context generation

From Xiao et al. (2023) - Streaming with Attention Sinks:
| Model | Context | Without sinks | With 4 sinks |
|---|---|---|---|
| LLaMA-7B | 4K | 45.3 PPL | 15.8 PPL |
| LLaMA-7B | 8K | ∞ (diverges) | 16.2 PPL |
| LLaMA-7B | 16K | ∞ (diverges) | 17.1 PPL |
| LLaMA-13B | 4K | 38.7 PPL | 14.2 PPL |
| LLaMA-13B | 8K | ∞ (diverges) | 14.9 PPL |
Attention sinks enable up to 10× longer context than training length with stable performance.

Memory efficiency

| Generation length | Memory (no sinks) | Memory (with sinks) | Reduction |
|---|---|---|---|
| 4K tokens | 4K cache | 2K cache | 50% |
| 8K tokens | 8K cache | 2K cache | 75% |
| 16K tokens | 16K cache | 2K cache | 87.5% |
| 32K tokens | 32K cache | 2K cache | 93.75% |
Breakthrough: Memory usage is constant regardless of generation length!

Comparison to alternatives

| Method | Memory | Quality | Implementation |
|---|---|---|---|
| Full attention | O(n²) | Best | Simple |
| Sliding window | O(w) | Poor (loses old context) | Simple |
| Sparse attention | O(n√n) | Moderate | Complex |
| Attention sinks | O(w) | Excellent | Moderate |
| Recurrent memory | O(m) | Good | Complex |
Where:
  • n = sequence length
  • w = window size
  • m = memory size
Attention sinks provide the best trade-off between memory efficiency, generation quality, and implementation complexity.

Advanced topics

Allocate different numbers of sinks per layer:
# Early layers: fewer sinks (local patterns)
# Later layers: more sinks (global reasoning)

def get_num_sinks(layer_idx, n_layers):
    if layer_idx < n_layers // 3:
        return 1  # Early: 1 sink
    elif layer_idx < 2 * n_layers // 3:
        return 2  # Middle: 2 sinks  
    else:
        return 4  # Late: 4 sinks
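Applied to a 12-layer model, the schedule above assigns sinks as follows (the function is repeated here so the sketch runs on its own):

```python
def get_num_sinks(layer_idx, n_layers):
    # Early layers handle local patterns, later layers global reasoning
    if layer_idx < n_layers // 3:
        return 1
    elif layer_idx < 2 * n_layers // 3:
        return 2
    else:
        return 4

sinks = [get_num_sinks(i, 12) for i in range(12)]
print(sinks)  # [1, 1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4]
```

The total is 28 sink slots instead of a uniform 48 (4 per layer), which is where the efficiency gain comes from.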
Research shows this can slightly improve efficiency with minimal quality loss.
Use different sinks for different context levels:
# Sink 0: Token-level (recent few tokens)
# Sink 1: Sentence-level (current sentence)
# Sink 2: Paragraph-level (current paragraph)  
# Sink 3: Document-level (overall topic)
Train sinks with auxiliary losses to enforce this hierarchy.
Fine-tune sinks for specific tasks:
Summarization:
# Sink learns to store desired summary length
# Sink learns key topics to include
Question answering:
# Sink stores the question embedding
# Sink stores expected answer type
Code generation:
# Sink stores language/framework context
# Sink stores current function signature
Visualize what models learn to store in sinks:
import matplotlib.pyplot as plt

# attn_weights: (batch, n_heads, query_len, key_len)
# Compute average attention to each sink
attn_to_sinks = attn_weights[:, :, :, :num_sinks].mean(dim=(0, 1, 2))

print(f"Attention to sink 0: {attn_to_sinks[0]:.3f}")
print(f"Attention to sink 1: {attn_to_sinks[1]:.3f}")

# Plot attention heatmap
plt.imshow(attn_weights[0, 0, :, :num_sinks])
plt.xlabel("Sink token")
plt.ylabel("Query position")
Common patterns:
  • First sink: Receives stable ~5-10% attention
  • Second sink: Receives variable attention (task-dependent)
  • Later tokens: Reduced “attention collapse” to first token

Common issues

Symptom: Attention to sinks is near-zero
Causes:
  1. Training sequences too short (model doesn’t need sinks)
  2. Learning rate too low (sinks don’t learn)
  3. Initialization too small (sinks are ignored)
Solutions:
# Train on longer sequences
config.max_seq_len = 4096  # from 1024

# Slightly larger initialization
self.sink_states = nn.Parameter(
    torch.randn(n_sinks, d_model) * 0.05  # from 0.02
)

# Separate learning rate for sinks
optimizer = AdamW([
    {'params': model.base_params, 'lr': 3e-4},
    {'params': model.sink_params, 'lr': 1e-3},  # Higher!
])
Issue: RoPE positions conflict with sink positions
Solution: Ensure correct position offsets (already implemented):
# Sinks get positions 0, 1, ..., n_sinks - 1
sink_k = apply_rope(sink_k, n_sinks, offset=0)

# Sequence keys continue from n_sinks
k = apply_rope(k, seq_len, offset=n_sinks)

# Queries see correct relative positions
q = apply_rope(q, seq_len, offset=n_sinks)
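To make the offset logic concrete, here is a self-contained sketch with a minimal rotary implementation. This apply_rope is a toy stand-in for the project's function, not its actual code; it rotates dimension pairs by position-dependent angles, and the offsets place sinks at positions 0..n_sinks-1 with the sequence continuing after them.

```python
import torch

def apply_rope(x, offset=0):
    # Toy RoPE: rotate (even, odd) dim pairs by angle pos / 10000^(2i/d)
    b, h, t, d = x.shape
    pos = torch.arange(offset, offset + t, dtype=torch.float32)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, d, 2).float() / d))
    ang = pos[:, None] * inv_freq[None, :]  # (t, d/2)
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

n_sinks, seq_len, d = 2, 6, 8
k_sinks = apply_rope(torch.randn(1, 1, n_sinks, d), offset=0)       # positions 0..1
k_seq   = apply_rope(torch.randn(1, 1, seq_len, d), offset=n_sinks) # positions 2..7
q       = apply_rope(torch.randn(1, 1, seq_len, d), offset=n_sinks) # aligned with keys
```

Because keys and queries share the same position numbering, relative offsets between query and key stay correct even with sinks prepended.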
Issue: KV cache grows despite sinks
Cause: Not evicting old tokens correctly
Solution: Always keep sinks at the start:
# WRONG: Evicting includes sinks
cache = cache[:, :, -max_length:]

# CORRECT: Keep sinks + recent tokens
cache = torch.cat([
    cache[:, :, :num_sinks],           # Keep sinks
    cache[:, :, -(max_length - num_sinks):]  # Keep recent
], dim=2)

References

Train Short, Test Long

Press et al., 2021 - Motivation for attention sinks

Efficient Streaming Language Models with Attention Sinks

Xiao et al., 2023 - Explicit attention sinks implementation and analysis

LM-Infinite: Simple On-the-Fly Length Generalization

Han et al., 2023 - Related approach to infinite generation

Landmark Attention

Mohtashami & Jaggi, 2023 - Similar concept with “landmark” tokens

See also

Architecture overview

Learn about the full model architecture

RoPE

Position encoding used with attention sinks

Multi-head attention

Full attention implementation details

Configuration

Configure attention sink parameters
