Overview

The Grok transformer provides the core neural architecture for both ranking and retrieval models. It implements a decoder-only transformer with grouped-query attention (GQA), RoPE positional encodings, and RMS normalization.

Classes

Transformer

A transformer stack implementing the decoder architecture.
Parameters:
  • num_q_heads (int, required): Number of query attention heads
  • num_kv_heads (int, required): Number of key/value attention heads (for grouped-query attention)
  • key_size (int, required): Dimension of attention keys and queries
  • widening_factor (float, required): Factor for expanding the FFN hidden dimension
  • attn_output_multiplier (float, required): Multiplier applied to attention logits
  • num_layers (int, required): Number of transformer layers
  • name (Optional[str]): Optional name for the transformer

TransformerConfig

Configuration dataclass for the transformer architecture.
Parameters:
  • emb_size (int, required): Embedding dimension size
  • key_size (int, required): Dimension of attention keys
  • num_q_heads (int, required): Number of query heads
  • num_kv_heads (int, required): Number of key/value heads
  • num_layers (int, required): Number of transformer layers
  • widening_factor (float, default 4.0): FFN widening factor
  • attn_output_multiplier (float, default 1.0): Attention output multiplier
  • name (Optional[str]): Optional configuration name

DecoderLayer

A single transformer decoder layer.
Parameters:
  • num_q_heads (int, required): Number of query attention heads
  • num_kv_heads (int, required): Number of key/value attention heads
  • key_size (int, required): Dimension of attention keys
  • num_layers (int, required): Total number of layers in the stack
  • layer_index (Optional[int]): Index of this layer in the stack
  • widening_factor (float, default 4.0): FFN widening factor
  • attn_output_multiplier (float, default 1.0): Attention output multiplier
  • name (Optional[str]): Optional layer name

MultiHeadAttention

Multi-head attention with grouped-query attention and RoPE.
Parameters:
  • num_q_heads (int, required): Number of query heads
  • num_kv_heads (int, required): Number of key/value heads
  • key_size (int, required): Dimension of keys and queries
  • with_bias (bool, default True): Whether to use bias in projections
  • value_size (Optional[int]): Dimension of values (defaults to key_size)
  • model_size (Optional[int]): Model dimension (defaults to key_size * num_q_heads)
  • attn_output_multiplier (float, default 1.0): Multiplier for attention logits
  • name (Optional[str]): Optional module name

RotaryEmbedding

Applies rotary positional embeddings (RoPE) as described in RoFormer.
Parameters:
  • dim (int, required): Dimensionality of the feature vectors (must be even)
  • base_exponent (int, default 10000): Base exponent for computing frequencies
  • name (Optional[str]): Optional module name

Named Tuples

TransformerOutput

Output of the transformer.
Fields:
  • embeddings (jax.Array): Output embeddings from the transformer [B, T, D]

DecoderOutput

Output of a decoder layer.
Fields:
  • embeddings (jax.Array): Output embeddings from the layer [B, T, D]

MHAOutput

Output of multi-head attention.
Fields:
  • embeddings (jax.Array): Output embeddings from attention [B, T, D]

TrainingState

Container for training state.
Fields:
  • params (hk.Params): Model parameters

Methods

Transformer.__call__

def __call__(
    embeddings: jax.Array,
    mask: jax.Array,
    candidate_start_offset: Optional[int] = None,
) -> TransformerOutput
Transforms input embedding sequences to output embedding sequences.
Parameters:
  • embeddings (jax.Array, required): Input embeddings [B, T, D]
  • mask (jax.Array, required): Padding mask [B, T], True for valid positions
  • candidate_start_offset (Optional[int]): If provided, positions >= this offset are treated as candidates that can only attend to positions before the offset (user+history) and to themselves (self-attention), but not to other candidates. Used for recommendation-system inference.

Returns:
  • output (TransformerOutput): Transformer output containing embeddings [B, T, D]

RotaryEmbedding.__call__

def __call__(
    x: jax.Array,
    seq_dim: int,
    offset: jax.Array,
    const_position: Optional[int] = None,
    t: Optional[jax.Array] = None,
) -> jax.Array
Apply rotary embeddings to input tensor.
Parameters:
  • x (jax.Array, required): Input tensor to apply RoPE to
  • seq_dim (int, required): Dimension index corresponding to the sequence
  • offset (jax.Array, required): Position offset (scalar or per-batch element)
  • const_position (Optional[int]): If provided, use a constant position for all tokens
  • t (Optional[jax.Array]): Custom position indices [B, T]

Returns:
  • output (jax.Array): Tensor with rotary embeddings applied
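
The rotation itself can be sketched as follows. This is a NumPy sketch of the standard half-rotation RoFormer formulation (jax.numpy shares this API); the helper names and the [B, T, dim] layout are illustrative, not the module's exact internals:

```python
import numpy as np

def rotate_half(x):
    # (x1, x2) -> (-x2, x1): pairs feature i with feature i + dim/2.
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def apply_rope(x, offset=0, base_exponent=10000):
    # x: [B, T, dim] with dim even; sequence axis assumed to be 1.
    _, t, dim = x.shape
    # One frequency per feature pair: base_exponent^(-2i/dim).
    inv_freq = 1.0 / (base_exponent ** (np.arange(0, dim, 2) / dim))
    positions = np.arange(t) + offset                   # [T]
    angles = np.einsum("t,f->tf", positions, inv_freq)  # [T, dim/2]
    angles = np.concatenate([angles, angles], axis=-1)  # [T, dim]
    # Rotate each (x_i, x_{i+dim/2}) pair by its position-dependent angle.
    return x * np.cos(angles) + rotate_half(x) * np.sin(angles)
```

Because each feature pair is rotated, the per-token vector norm is preserved, and the token at position 0 (with offset 0) passes through unchanged.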

Utility Functions

make_recsys_attn_mask

def make_recsys_attn_mask(
    seq_len: int,
    candidate_start_offset: int,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array
Create attention mask for recommendation system inference. Creates a mask where:
  • Positions 0 to candidate_start_offset-1 (user+history): causal attention
  • Positions candidate_start_offset onwards (candidates): can attend to user+history and themselves (self-attention), but NOT to other candidates
This ensures each candidate is scored independently based on user+history context.
Parameters:
  • seq_len (int, required): Total sequence length (user + history + candidates)
  • candidate_start_offset (int, required): Position where candidates start in the sequence
  • dtype (jnp.dtype, default jnp.float32): Data type for the mask

Returns:
  • mask (jax.Array): Attention mask [1, 1, seq_len, seq_len] where 1 means "can attend"
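
The mask semantics above can be sketched directly. This NumPy sketch (jax.numpy shares the API, though JAX arrays would use functional updates rather than in-place assignment) builds a causal mask and then rewrites the candidate rows:

```python
import numpy as np

def make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=np.float32):
    # Start from a standard causal (lower-triangular) mask.
    mask = np.tril(np.ones((seq_len, seq_len), dtype=dtype))
    off = candidate_start_offset
    # Candidate rows: clear everything...
    mask[off:, :] = 0.0
    # ...then restore the user+history columns and the diagonal (self-attention),
    # so candidates never see other candidates.
    mask[off:, :off] = 1.0
    mask[np.arange(off, seq_len), np.arange(off, seq_len)] = 1.0
    # Broadcastable over batch and heads: [1, 1, seq_len, seq_len].
    return mask[None, None, :, :]
```

With seq_len=5 and candidate_start_offset=3, row 3 can attend positions 0-3 and row 4 can attend positions 0-2 and 4, so each candidate is scored independently.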

ffn_size

def ffn_size(emb_size: int, widening_factor: float) -> int
Calculate FFN hidden dimension from embedding size and widening factor.
Parameters:
  • emb_size (int, required): Embedding dimension
  • widening_factor (float, required): Widening factor for the FFN

Returns:
  • size (int): FFN hidden size (adjusted to be a multiple of 8)
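
A direct sketch of this calculation, combining the two-thirds GeGLU scaling with the multiple-of-8 adjustment (rounding up to the next multiple of 8 is an assumption here; the documented behavior only says "adjusted"):

```python
def ffn_size(emb_size: int, widening_factor: float) -> int:
    # Two-thirds scaling compensates for GeGLU using two input projections.
    size = int(widening_factor * emb_size) * 2 // 3
    # Round up to a multiple of 8 for hardware-friendly shapes.
    return size + (8 - size % 8) % 8
```

For example, emb_size=48 with widening_factor=4.0 gives int(192) * 2 // 3 = 128, already a multiple of 8.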

layer_norm

def layer_norm(x: jax.Array) -> jax.Array
Apply RMS normalization to input.
Parameters:
  • x (jax.Array, required): Input tensor

Returns:
  • output (jax.Array): RMS-normalized tensor
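
The core of RMS normalization can be sketched as below (NumPy sketch; the epsilon value is an assumption, and the real module likely also applies a learned scale, which is omitted here):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Divide each feature vector by its root-mean-square.
    # Unlike LayerNorm, there is no mean subtraction.
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return x / rms
```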

rotate_half

def rotate_half(x: jax.Array) -> jax.Array
Obtain the rotated counterpart of each feature for RoPE.
Parameters:
  • x (jax.Array, required): Input tensor

Returns:
  • output (jax.Array): Rotated tensor
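
The standard RoFormer-style rotation pairs feature i with feature i + dim/2; a minimal NumPy sketch (jax.numpy shares the API):

```python
import numpy as np

def rotate_half(x):
    # Split the feature dimension in half and swap the halves,
    # negating the second: (x1, x2) -> (-x2, x1).
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)
```

Applying it twice negates the input, matching a 90-degree rotation of each feature pair.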

Architecture Details

Attention Mechanism

  • Grouped-Query Attention (GQA): Reduces KV cache size by sharing key/value heads across multiple query heads
  • Rotary Positional Embeddings (RoPE): Encodes positional information directly into attention keys and queries
  • Attention Clipping: Logits are clipped using tanh to prevent overflow: 30.0 * tanh(logits / 30.0)
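
The clipping step can be sketched as follows (NumPy sketch using the 30.0 clip value from above; the function name and the exact placement of attn_output_multiplier are illustrative):

```python
import numpy as np

def clipped_attn_logits(q, k, attn_output_multiplier=1.0, clip=30.0):
    # q: [T, d], k: [T, d] -> raw attention logits [T, T].
    logits = (q @ k.T) * attn_output_multiplier
    # Soft clipping: approximately linear for |logits| << clip,
    # saturating at +/- clip, which prevents overflow in the softmax.
    return clip * np.tanh(logits / clip)
```

Unlike a hard clamp, tanh keeps the clipping differentiable, so gradients still flow (attenuated) through saturated logits.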

Normalization

  • Uses RMS normalization instead of LayerNorm for efficiency
  • Applied before attention and FFN blocks (pre-norm architecture)

Feed-Forward Network

  • Uses GeGLU activation: GELU(W1 * x) * (W2 * x)
  • Hidden dimension calculated as: int(widening_factor * emb_size) * 2 // 3
  • Adjusted to be a multiple of 8 for hardware efficiency
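
The GeGLU computation above can be sketched as follows (NumPy sketch with a tanh-approximated GELU; the weight names and the absence of biases are assumptions for illustration):

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def geglu_ffn(x, w1, w2, w_out):
    # x: [T, D]; w1, w2: [D, H]; w_out: [H, D], with H = ffn_size(D, widening_factor).
    # One projection is gated by the GELU of the other.
    return (gelu(x @ w1) * (x @ w2)) @ w_out
```

The gating is why the hidden dimension is scaled by 2/3: GeGLU needs two input projections (w1 and w2), so shrinking H keeps the parameter count comparable to a plain GELU FFN with the full widening factor.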

Special Attention Mask

The make_recsys_attn_mask function creates a specialized attention pattern for ranking:
User+History | Candidates
-------------|-----------
   Causal    | Self-only
  • User and history tokens use causal attention (can attend to previous tokens)
  • Candidate tokens can attend to all user+history tokens and themselves
  • Candidates cannot attend to other candidates (ensures independent scoring)
