# Attention Mechanisms
```typescript
import { tensor } from "deepbox/ndarray";
import { MultiheadAttention } from "deepbox/nn";
// GradTensor is referenced in the shape checks below; import it from
// wherever your deepbox version exports it (this path is an assumption).
import { GradTensor } from "deepbox/autograd";

// MultiheadAttention(embedDim, numHeads)
// embedDim must be divisible by numHeads
const mha = new MultiheadAttention(8, 2);
console.log("MultiheadAttention(embedDim=8, numHeads=2)");
console.log("Each head has dimension 8/2 = 4");

// Input: (batch, seqLen, embedDim)
const seqData = tensor([[
  [1, 0, 1, 0, 1, 0, 1, 0],
  [0, 1, 0, 1, 0, 1, 0, 1],
  [1, 1, 0, 0, 1, 1, 0, 0],
]]);
console.log(`\nInput shape: [${seqData.shape.join(", ")}] (batch=1, seq=3, embed=8)`);

// Self-attention: query = key = value = input
const attnOut = mha.forward(seqData, seqData, seqData);
// forward() may return a GradTensor wrapper or a plain tensor
const attnShape = attnOut instanceof GradTensor ? attnOut.tensor.shape : attnOut.shape;
console.log(`Output shape: [${attnShape.join(", ")}]`);
console.log("Each position attends to all other positions");
```
```
MultiheadAttention(embedDim=8, numHeads=2)
Each head has dimension 8/2 = 4

Input shape: [1, 3, 8] (batch=1, seq=3, embed=8)
Output shape: [1, 3, 8]
Each position attends to all other positions
```
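The head-dimension arithmetic above (8 / 2 = 4) can be sketched without deepbox. This is a hypothetical helper, not part of the library's API; it just shows the divisibility constraint the constructor enforces.

```typescript
// Each head works on an embedDim / numHeads slice of the embedding,
// so embedDim must divide evenly by numHeads.
function headDim(embedDim: number, numHeads: number): number {
  if (embedDim % numHeads !== 0) {
    throw new Error(
      `embedDim ${embedDim} is not divisible by numHeads ${numHeads}`,
    );
  }
  return embedDim / numHeads;
}

console.log(headDim(8, 2)); // 4: each of the 2 heads sees a 4-dim slice
```

A constructor call like `new MultiheadAttention(10, 3)` would violate this constraint, since 10 / 3 is not an integer.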
```typescript
import { TransformerEncoderLayer } from "deepbox/nn";

// TransformerEncoderLayer combines:
// MultiheadAttention + FeedForward + LayerNorm + Dropout
const encoderLayer = new TransformerEncoderLayer(8, 2, 16);
console.log("\nTransformerEncoderLayer(dModel=8, nHead=2, dimFeedforward=16)");
console.log(`Input shape: [${seqData.shape.join(", ")}]`);

const encoderOut = encoderLayer.forward(seqData);
const encShape = encoderOut instanceof GradTensor ? encoderOut.tensor.shape : encoderOut.shape;
console.log(`Output shape: [${encShape.join(", ")}]`);
console.log("Full transformer encoder block with residual connections");
```
```
TransformerEncoderLayer(dModel=8, nHead=2, dimFeedforward=16)
Input shape: [1, 3, 8]
Output shape: [1, 3, 8]
Full transformer encoder block with residual connections
```
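The dataflow inside such an encoder layer can be sketched in plain TypeScript. The sketch below uses the classic post-norm arrangement (sublayer, residual add, then LayerNorm) with stand-in functions for the attention and feedforward sublayers; deepbox's internal ordering and dropout placement may differ.

```typescript
type Vec = number[];

// Normalize a vector to zero mean and unit variance (one LayerNorm row).
function layerNorm(x: Vec): Vec {
  const mean = x.reduce((s, v) => s + v, 0) / x.length;
  const variance = x.reduce((s, v) => s + (v - mean) ** 2, 0) / x.length;
  return x.map((v) => (v - mean) / Math.sqrt(variance + 1e-5));
}

const add = (a: Vec, b: Vec): Vec => a.map((v, i) => v + b[i]);

// Post-norm encoder layer: two sublayers, each wrapped in
// residual connection + LayerNorm.
function encoderLayer(
  x: Vec,
  selfAttention: (v: Vec) => Vec,
  feedForward: (v: Vec) => Vec,
): Vec {
  const a = layerNorm(add(x, selfAttention(x))); // sublayer 1: attention
  return layerNorm(add(a, feedForward(a)));      // sublayer 2: feedforward
}

// Stand-in sublayers (identity-free zeros) just to exercise the wiring.
const zero = (v: Vec): Vec => v.map(() => 0);
const y = encoderLayer([1, 3], zero, zero); // y ≈ [-1, 1] after normalization
```

With zero sublayers the residual path carries the input straight through, so the output is simply the layer-normalized input: the residual connections are what let gradients bypass a sublayer entirely.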
```typescript
const mhaParams = Array.from(mha.parameters()).length;
const encParams = Array.from(encoderLayer.parameters()).length;

console.log("\nParameter Counts:");
console.log(`MultiheadAttention params: ${mhaParams}`);
console.log(`TransformerEncoderLayer params: ${encParams}`);
console.log("Encoder layer includes attention + feedforward + normalization");
```
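The `parameters()` calls above count parameter *tensors*. As a cross-check, here is the back-of-the-envelope *scalar* parameter arithmetic for a standard transformer layout at these sizes; deepbox's internal layout (e.g. fused QKV projections) may group these differently.

```typescript
const d = 8;    // dModel / embedDim
const dff = 16; // dimFeedforward

// Attention: Q, K, V, and output projections, each a d×d weight + d bias.
const mhaScalars = 4 * (d * d + d);

// Feedforward: d→dff and dff→d linear layers (weights + biases).
const ffnScalars = (d * dff + dff) + (dff * d + d);

// Two LayerNorms, each with a scale and shift vector of length d.
const normScalars = 2 * (2 * d);

console.log(mhaScalars);                              // 288
console.log(mhaScalars + ffnScalars + normScalars);   // 600
```

The attention block alone accounts for roughly half the layer's weights at this feedforward ratio; scaling `dimFeedforward` to the common 4×dModel shifts that balance toward the feedforward sublayer.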
## How Attention Works
- Query-Key-Value: Input is projected into three representations
- Attention scores: Computed by comparing queries with keys
- Softmax: Scores normalized to attention weights
- Weighted sum: Values weighted by attention scores
- Multi-head: Multiple attention patterns learned in parallel
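The steps above can be sketched for a single head in plain TypeScript, using `number[][]` matrices instead of deepbox tensors. This is a minimal illustration of softmax(QKᵀ/√d)·V, not the library's implementation.

```typescript
// Matrix product a (m×k) · b (k×n).
function matmul(a: number[][], b: number[][]): number[][] {
  return a.map((row) =>
    b[0].map((_, j) => row.reduce((s, v, k) => s + v * b[k][j], 0)),
  );
}

const transpose = (m: number[][]): number[][] =>
  m[0].map((_, j) => m.map((row) => row[j]));

// Row-wise softmax (numerically stabilized by subtracting the row max).
function softmaxRows(m: number[][]): number[][] {
  return m.map((row) => {
    const max = Math.max(...row);
    const exps = row.map((v) => Math.exp(v - max));
    const sum = exps.reduce((s, v) => s + v, 0);
    return exps.map((v) => v / sum);
  });
}

// Scaled dot-product attention: softmax(Q Kᵀ / √dHead) · V.
// Q, K, V each have shape (seqLen, dHead).
function attention(Q: number[][], K: number[][], V: number[][]): number[][] {
  const dHead = Q[0].length;
  const scores = matmul(Q, transpose(K)).map((row) =>
    row.map((v) => v / Math.sqrt(dHead)),
  );
  return matmul(softmaxRows(scores), V); // weighted sum of values
}

// Toy self-attention: query = key = value.
const Q = [[1, 0], [0, 1]];
const out = attention(Q, Q, Q); // out[0] ≈ [0.67, 0.33]
```

Each output row is a convex combination of the value rows: the softmax guarantees the attention weights in every row are non-negative and sum to 1. Multi-head attention runs several such computations in parallel on different learned projections and concatenates the results.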
## Advantages Over RNNs
- Parallelization: All positions processed simultaneously
- Long-range dependencies: Direct connections between any positions
- Interpretability: Attention weights show which positions interact
- Scalability: Efficient on modern hardware
## Use Cases
- Natural language processing (BERT, GPT)
- Machine translation
- Text summarization
- Code generation
- Image processing (Vision Transformers)
- Multi-modal learning
## Next Steps

- Layer Normalization: Essential for training deep transformers
- Positional Encoding: Add position information to sequences