Overview
The Grok transformer provides the core neural architecture for both the ranking and retrieval models. It implements a decoder-only transformer with grouped-query attention (GQA), RoPE positional encodings, and RMS normalization.

Classes
Transformer
A transformer stack implementing the decoder architecture.

Parameters:
- Number of query attention heads
- Number of key/value attention heads (for grouped-query attention)
- Dimension of attention keys and queries
- Factor for expanding the FFN hidden dimension
- Multiplier applied to attention logits
- Number of transformer layers
- Optional name for the transformer
TransformerConfig
Configuration dataclass for the transformer architecture.

Parameters:
- Embedding dimension size
- Dimension of attention keys
- Number of query heads
- Number of key/value heads
- Number of transformer layers
- FFN widening factor
- Attention output multiplier
- Optional configuration name
DecoderLayer
A single transformer decoder layer.

Parameters:
- Number of query attention heads
- Number of key/value attention heads
- Dimension of attention keys
- Total number of layers in the stack
- Index of this layer in the stack
- FFN widening factor
- Optional layer name
- Attention output multiplier
MultiHeadAttention
Multi-head attention with grouped-query attention and RoPE.

Parameters:
- Number of query heads
- Number of key/value heads
- Dimension of keys and queries
- Whether to use bias in projections
- Dimension of values (defaults to `key_size`)
- Model dimension (defaults to `key_size * num_q_heads`)
- Multiplier for attention logits
- Optional module name
RotaryEmbedding
Applies rotary positional embeddings (RoPE) as described in RoFormer.

Parameters:
- Dimensionality of the feature vectors (must be even)
- Optional module name
- Base exponent for computing frequencies
Named Tuples
TransformerOutput
Output of the transformer.

Fields:
- Output embeddings from the transformer [B, T, D]
DecoderOutput
Output of a decoder layer.

Fields:
- Output embeddings from the layer [B, T, D]
MHAOutput
Output of multi-head attention.

Fields:
- Output embeddings from attention [B, T, D]
TrainingState
Container for training state.

Fields:
- Model parameters
Methods
Transformer.__call__
Parameters:
- Input embeddings [B, T, D]
- Padding mask [B, T], True for valid positions
- If provided, positions >= this offset are treated as candidates that can attend only to positions before the offset (user+history) and to themselves (self-attention), but not to other candidates. Used for recommendation-system inference.

Returns:
- Transformer output containing embeddings [B, T, D]
RotaryEmbedding.__call__
Parameters:
- Input tensor to apply RoPE to
- Dimension index corresponding to the sequence
- Position offset (scalar or per-batch element)
- If provided, a constant position is used for all tokens
- Custom position indices [B, T]

Returns:
- Tensor with rotary embeddings applied
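To make the call semantics concrete, here is a minimal NumPy sketch of applying rotary embeddings with the split-half `rotate_half` convention. The function names and the exact rotation convention are assumptions for illustration, not the module's actual implementation:

```python
import numpy as np

def rotate_half(x):
    # Split the last dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def apply_rope(x, positions, base_exponent=10000.0):
    # x: [B, T, D] with D even; positions: [B, T] integer token positions.
    d = x.shape[-1]
    inv_freq = 1.0 / base_exponent ** (np.arange(0, d, 2) / d)   # [D/2]
    angles = positions[..., None] * inv_freq                     # [B, T, D/2]
    angles = np.concatenate([angles, angles], axis=-1)           # [B, T, D]
    # Rotate each (x_i, x_{i+D/2}) pair by its position-dependent angle.
    return x * np.cos(angles) + rotate_half(x) * np.sin(angles)
```

Because RoPE is a pure rotation, it preserves vector norms, and position 0 leaves the input unchanged.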
Utility Functions
make_recsys_attn_mask
Creates an attention mask with two regimes:
- Positions 0 to `candidate_start_offset - 1` (user+history): causal attention
- Positions `candidate_start_offset` onwards (candidates): can attend to user+history and themselves (self-attention), but NOT to other candidates

Parameters:
- Total sequence length (user + history + candidates)
- Position where candidates start in the sequence
- Data type for the mask

Returns:
- Attention mask [1, 1, seq_len, seq_len] where 1 means "can attend"
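A NumPy sketch of this mask, grounded in the semantics documented above (a causal prefix plus mutually isolated candidates); the construction details are an assumption, not the library's actual code:

```python
import numpy as np

def make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=np.float32):
    # Start from a standard causal (lower-triangular) mask.
    mask = np.tril(np.ones((seq_len, seq_len)))
    # Candidate positions are those at or after the offset.
    cand = np.arange(seq_len) >= candidate_start_offset
    # Remove candidate -> candidate attention entirely...
    mask[np.ix_(cand, cand)] = 0.0
    # ...then restore self-attention on the diagonal.
    np.fill_diagonal(mask, 1.0)
    # Shape [1, 1, seq_len, seq_len]; 1 means "can attend".
    return mask[None, None].astype(dtype)
```

With `seq_len=5` and `candidate_start_offset=3`, rows 3 and 4 can see positions 0-2 and themselves, but not each other, which is what makes candidate scores independent.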
ffn_size
Parameters:
- Embedding dimension
- Widening factor for the FFN

Returns:
- FFN hidden size (adjusted to be a multiple of 8)
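A sketch of this computation, using the 2/3 scaling and multiple-of-8 adjustment described in the Architecture Details section below; rounding up (rather than down) to the multiple of 8 is an assumption:

```python
def ffn_size(emb_size, widening_factor):
    # GeGLU uses two input projections, so the raw width is scaled by 2/3
    # to keep the parameter count comparable to a standard FFN.
    hidden = int(widening_factor * emb_size) * 2 // 3
    # Round up to a multiple of 8 for hardware efficiency (assumed direction).
    return hidden + (-hidden % 8)
```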
layer_norm
Parameters:
- Input tensor

Returns:
- RMS-normalized tensor
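A minimal RMSNorm sketch for reference; the epsilon value and the optional learned scale are assumptions, not the module's actual defaults:

```python
import numpy as np

def rms_norm(x, scale=None, eps=1e-6):
    # RMSNorm divides by the root-mean-square of the features;
    # unlike LayerNorm there is no mean subtraction and no bias.
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    out = x / rms
    return out * scale if scale is not None else out
```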
rotate_half
Parameters:
- Input tensor

Returns:
- Rotated tensor
Architecture Details
Attention Mechanism
- Grouped-Query Attention (GQA): Reduces KV cache size by sharing key/value heads across multiple query heads
- Rotary Positional Embeddings (RoPE): Encodes positional information directly into attention keys and queries
- Attention Clipping: Logits are soft-clipped using `tanh` to prevent overflow: `30.0 * tanh(logits / 30.0)`
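The clipping formula above can be sketched directly; the function name is illustrative, but the expression matches the one documented here:

```python
import numpy as np

def clip_attn_logits(logits, max_value=30.0):
    # Soft clipping: approximately the identity for small logits,
    # smoothly saturating at +/- max_value for large ones.
    return max_value * np.tanh(logits / max_value)
```

Unlike a hard clamp, this keeps gradients nonzero everywhere while bounding the logits to (-30, 30).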
Normalization
- Uses RMS normalization instead of LayerNorm for efficiency
- Applied before attention and FFN blocks (pre-norm architecture)
Feed-Forward Network
- Uses GeGLU activation: `GELU(W1 * x) * (W2 * x)`
- Hidden dimension calculated as `int(widening_factor * emb_size) * 2 // 3`
- Adjusted to be a multiple of 8 for hardware efficiency
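A sketch of the GeGLU FFN described above; the tanh approximation of GELU and the weight-matrix names are assumptions for illustration:

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU (assumed; exact erf-based GELU also works).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def geglu_ffn(x, w1, w2, w_out):
    # GeGLU: GELU(W1 * x) gates the linear branch (W2 * x),
    # followed by the output projection.
    return (gelu(x @ w1) * (x @ w2)) @ w_out
```

The gating branch is why `ffn_size` scales the hidden width by 2/3: two projection matrices share the parameter budget of one wider FFN.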
Special Attention Mask
The `make_recsys_attn_mask` function creates a specialized attention pattern for ranking:
- User and history tokens use causal attention (can attend to previous tokens)
- Candidate tokens can attend to all user+history tokens and themselves
- Candidates cannot attend to other candidates (ensures independent scoring)