This page provides detailed technical information about the Phoenix transformer architecture, including the critical attention masking mechanism that enables candidate isolation.

Transformer Configuration

Phoenix uses a standard decoder-only transformer with several key architectural choices:
@dataclass
class TransformerConfig:
    emb_size: int                    # Model dimension (e.g., 2048)
    key_size: int                    # Attention head dimension (e.g., 128)
    num_q_heads: int                 # Number of query heads (e.g., 16)
    num_kv_heads: int                # Number of key/value heads (e.g., 8)
    num_layers: int                  # Number of transformer layers (e.g., 24)
    widening_factor: float = 4.0     # FFN expansion factor
    attn_output_multiplier: float = 1.0  # Attention scaling
Source: grok.py:88-100
Grouped Query Attention: Phoenix uses num_q_heads > num_kv_heads, which reduces memory usage and increases inference speed while maintaining model quality.
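The config can be exercised directly. A minimal sketch (the dataclass repeated here so the snippet is self-contained), instantiated with the example values from the field comments:

```python
from dataclasses import dataclass

@dataclass
class TransformerConfig:
    emb_size: int                        # Model dimension
    key_size: int                        # Attention head dimension
    num_q_heads: int                     # Number of query heads
    num_kv_heads: int                    # Number of key/value heads
    num_layers: int                      # Number of transformer layers
    widening_factor: float = 4.0         # FFN expansion factor
    attn_output_multiplier: float = 1.0  # Attention scaling

# Example values taken from the field comments above
config = TransformerConfig(
    emb_size=2048, key_size=128, num_q_heads=16, num_kv_heads=8, num_layers=24
)

# Grouped Query Attention: query heads divide evenly into KV groups
assert config.num_q_heads % config.num_kv_heads == 0
group_size = config.num_q_heads // config.num_kv_heads  # queries sharing one KV head
```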

Attention Masking for Candidate Isolation

The Problem

In standard transformer inference for sequences, each position can attend to all previous positions (causal masking). For recommendation ranking, this would allow candidates to see each other:
Standard Causal Mask (WRONG for ranking):
     [User][History...][Cand1][Cand2][Cand3]
User   ✓
Hist   ✓      ✓
Cand1  ✓      ✓       ✓
Cand2  ✓      ✓       ✓      ✓        ← Problem: Cand2 sees Cand1
Cand3  ✓      ✓       ✓      ✓      ✓  ← Problem: Cand3 sees Cand1 & Cand2
This creates batch composition effects: the score for Cand2 depends on whether Cand1 is in the batch.
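The effect can be demonstrated with a toy single-head attention computation in plain NumPy (a hypothetical sketch, not the Phoenix code): under a causal mask, changing Cand1 changes Cand2's output, while under an isolation mask Cand2's output is invariant to batch composition.

```python
import numpy as np

def attend(q, K, V, allowed):
    """Single-query attention restricted to `allowed` key positions."""
    logits = np.where(allowed, K @ q, -1e30)  # mask disallowed positions
    w = np.exp(logits - logits.max())
    w /= w.sum()
    return w @ V

rng = np.random.default_rng(0)
d = 8
ctx = rng.normal(size=(3, d))   # user + history at positions 0..2
cand2 = rng.normal(size=d)      # candidate at position 4 (held fixed)

causal_outs, isolated_outs = [], []
for _ in range(2):              # two different batch compositions
    cand1 = rng.normal(size=d)  # candidate at position 3 varies
    K = V = np.vstack([ctx, cand1, cand2])
    causal = np.array([True, True, True, True, True])     # Cand2 sees Cand1
    isolated = np.array([True, True, True, False, True])  # Cand2 does not
    causal_outs.append(attend(cand2, K, V, causal))
    isolated_outs.append(attend(cand2, K, V, isolated))

assert not np.allclose(causal_outs[0], causal_outs[1])  # depends on batch
assert np.allclose(isolated_outs[0], isolated_outs[1])  # batch-independent
```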

The Solution: Custom Attention Mask

Phoenix implements a specialized attention mask using make_recsys_attn_mask:
def make_recsys_attn_mask(
    seq_len: int,
    candidate_start_offset: int,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Create attention mask for recommendation system inference.

    Creates a mask where:
    - Positions 0 to candidate_start_offset-1 (user+history): causal attention
    - Positions candidate_start_offset onwards (candidates): can attend to user+history
      and themselves (self-attention), but NOT to other candidates

    This ensures each candidate is scored independently based on user+history context.

    Args:
        seq_len: Total sequence length (user + history + candidates)
        candidate_start_offset: Position where candidates start in the sequence
        dtype: Data type for the mask

    Returns:
        Attention mask of shape [1, 1, seq_len, seq_len] where 1 means "can attend"
    """
    # Start with causal mask for the full sequence
    causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=dtype))

    # Zero out candidate-to-candidate attention (bottom-right block)
    attn_mask = causal_mask.at[:, :, candidate_start_offset:, candidate_start_offset:].set(0)

    # Add back self-attention for candidates (diagonal of the candidate block)
    candidate_indices = jnp.arange(candidate_start_offset, seq_len)
    attn_mask = attn_mask.at[:, :, candidate_indices, candidate_indices].set(1)

    return attn_mask
Source: grok.py:39-71
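The same construction can be sketched in plain NumPy (a stand-in for the JAX version above, with the batch/head dimensions dropped) to sanity-check the mask's properties:

```python
import numpy as np

def make_recsys_attn_mask_np(seq_len, candidate_start_offset):
    # Causal base mask, then zero the candidate-to-candidate block,
    # then restore the diagonal so candidates attend to themselves.
    mask = np.tril(np.ones((seq_len, seq_len)))
    mask[candidate_start_offset:, candidate_start_offset:] = 0
    idx = np.arange(candidate_start_offset, seq_len)
    mask[idx, idx] = 1
    return mask

m = make_recsys_attn_mask_np(seq_len=6, candidate_start_offset=3)

# User + history block (positions 0-2) keeps causal attention
assert np.array_equal(m[:3, :3], np.tril(np.ones((3, 3))))
# Every candidate attends to the full user + history prefix
assert np.all(m[3:, :3] == 1)
# Candidate block is the identity: self-attention only
assert np.array_equal(m[3:, 3:], np.eye(3))
```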

Attention Mask Visualization

Here’s the full attention pattern for Phoenix ranking:
                     ATTENTION MASK VISUALIZATION

          Keys (what we attend TO)
          ─────────────────────────────────────────────▶

          │ User │    History (S positions)    │   Candidates (C positions)    │
     ┌────┼──────┼─────────────────────────────┼───────────────────────────────┤
     │    │      │                             │                               │
     │ U  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✗   ✗    │
     │    │      │                             │                               │
     ├────┼──────┼─────────────────────────────┼───────────────────────────────┤
  Q  │    │      │                             │                               │
  u  │ H  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✗   ✗    │
  e  │ i  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✗   ✗    │
  r  │ s  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✗   ✗    │
  i  │ t  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✗   ✗    │
  e  │    │      │                             │                               │
  s  ├────┼──────┼─────────────────────────────┼───────────────────────────────┤
     │    │      │                             │  DIAGONAL ONLY (self-attend)  │
  │  │ C  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✓   ✗   ✗   ✗   ✗   ✗   ✗    │
  │  │ a  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✓   ✗   ✗   ✗   ✗   ✗    │
  │  │ n  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✓   ✗   ✗   ✗   ✗    │
  │  │ d  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✓   ✗   ✗   ✗    │
  │  │ i  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✓   ✗   ✗    │
  │  │ d  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✓   ✗    │
  ▼  │ s  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✗   ✓    │
     │    │      │                             │                               │
     └────┴──────┴─────────────────────────────┴───────────────────────────────┘

     ✓ = Can attend (1)          ✗ = Cannot attend (0)

     Legend:
     ├─ User + History: Standard causal attention among themselves (drawn as a
     │    full block above for simplicity; the actual mask is lower-triangular)
     ├─ Candidates → User/History: Candidates CAN attend to user and history
     └─ Candidates → Candidates: Candidates CANNOT attend to each other (only self)
Key Insight: The bottom-right block (candidate-to-candidate attention) is diagonal only. Each candidate can attend to:
  • ✓ User embedding
  • ✓ All history positions
  • ✓ Itself (self-attention)
  • ✗ Other candidates

Transformer Forward Pass

The transformer applies the custom mask during attention computation:
class Transformer(hk.Module):
    def __call__(
        self,
        embeddings: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, T]
        candidate_start_offset: Optional[int] = None,
    ) -> TransformerOutput:
        """Transforms input embedding sequences to output embedding sequences.

        Args:
            embeddings: Input embeddings of shape [B, T, D]
            mask: Padding mask of shape [B, T], True for valid positions
            candidate_start_offset: If provided, positions >= this offset are treated as
                candidates that can only attend to positions before the offset (user+history)
                and themselves (self-attention), but not to other candidates.
                Used for recommendation system inference.

        Returns:
            TransformerOutput containing the output embeddings.
        """
        fprop_dtype = embeddings.dtype
        _, seq_len, _ = embeddings.shape
        padding_mask = mask.copy()
        mask = mask[:, None, None, :]  # [B, H=1, T'=1, T]

        if candidate_start_offset is not None:
            # Use recommendation system attention mask where candidates attend to
            # user+history and themselves, but not to other candidates
            attn_mask = make_recsys_attn_mask(seq_len, candidate_start_offset, fprop_dtype)
            mask = mask * attn_mask
        else:
            # Standard causal mask for autoregressive sequence modelling
            causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(fprop_dtype)
            mask = mask * causal_mask

        h = embeddings

        # Apply transformer layers
        for i in range(self.num_layers):
            decoder_output = DecoderLayer(...)(h, mask, padding_mask)
            h = decoder_output.embeddings

        return TransformerOutput(embeddings=h)
Source: grok.py:516-586
The candidate_start_offset parameter controls the masking behavior:
  • None: Standard causal mask (for language modeling or retrieval user tower)
  • int: Custom candidate isolation mask (for ranking)

Transformer Layer Components

1. Multi-Head Attention

Phoenix uses Rotary Position Embeddings (RoPE) for position-aware attention:
class MultiHeadAttention(hk.Module):
    def __call__(
        self,
        query: jax.Array,
        key: jax.Array,
        value: jax.Array,
        mask: jax.Array,
    ) -> MHAOutput:
        # Project to query, key, value
        query_heads = projection(query, self.key_size, self.num_q_heads)
        key_heads = projection(key, self.key_size, self.num_kv_heads)
        value_heads = projection(value, self.value_size, self.num_kv_heads)

        # Apply rotary embeddings
        rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))
        key_heads = rotate(key_heads, seq_dim=1, offset=0)
        query_heads = rotate(query_heads, seq_dim=1, offset=0)

        # Grouped query attention: reshape Q heads into kv_h groups that share
        # K/V heads (b=batch, t=seq_len, h=num_q_heads, kv_h=num_kv_heads, d=key_size)
        query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))

        # Compute attention scores
        attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads)
        attn_logits *= self.attn_output_multiplier
        
        # Apply tanh capping for stability
        max_attn_val = jnp.array(30.0)
        attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)

        # Apply mask
        if mask is not None:
            attn_logits = jnp.where(mask, attn_logits, -1e30)
        
        # Softmax and weighted sum
        attn_weights = jax.nn.softmax(attn_logits)
        attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads)

        return MHAOutput(final_projection(attn))
Source: grok.py:264-363
RoPE encodes position information by rotating the query and key vectors:
x = x * cos(phase) + rotate_half(x) * sin(phase)
This provides relative position information without adding explicit position embeddings.
Source: grok.py:205-261
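A minimal NumPy sketch of the rotation formula above (illustrative; the base and feature-pairing conventions follow the standard RoPE formulation and may differ in detail from grok.py's layout):

```python
import numpy as np

def rotate_half(x):
    # (x1, x2) -> (-x2, x1) on the two halves of the last dimension
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def apply_rope(x, base=10_000.0):
    # x: [seq_len, dim] with dim even; phase grows with position and
    # decays with feature index
    seq_len, dim = x.shape
    inv_freq = 1.0 / base ** (np.arange(dim // 2) / (dim // 2))
    phase = np.outer(np.arange(seq_len), inv_freq)   # [seq_len, dim // 2]
    phase = np.concatenate([phase, phase], axis=-1)  # [seq_len, dim]
    return x * np.cos(phase) + rotate_half(x) * np.sin(phase)

x = np.random.default_rng(0).normal(size=(16, 64))
y = apply_rope(x)

# The rotation preserves norms and leaves position 0 unchanged (phase = 0)
assert np.allclose(np.linalg.norm(y, axis=-1), np.linalg.norm(x, axis=-1))
assert np.allclose(y[0], x[0])
```

Because each feature pair is rotated by a position-dependent angle, query-key dot products end up depending only on the *relative* distance between positions.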
Phoenix applies tanh capping to attention logits for numerical stability:
attn_logits = 30.0 * tanh(attn_logits / 30.0)
This prevents extremely large logits that can cause softmax overflow.
Instead of creating separate K/V heads for each Q head, Phoenix uses fewer K/V heads and reuses them:
  • Reduces memory usage by ~50% (with 2:1 ratio)
  • Faster inference due to fewer key-value computations
  • Minimal quality degradation
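The KV-cache saving can be quantified with the example head counts (a back-of-the-envelope sketch; the bfloat16 element size and sequence length are assumptions):

```python
# Example config: 16 query heads, 8 KV heads, head dim 128
num_q_heads, num_kv_heads, key_size = 16, 8, 128
seq_len, bytes_per_value = 4096, 2  # bfloat16; hypothetical sequence length

def kv_cache_bytes(num_heads):
    # K and V caches: one vector per position per head
    return 2 * seq_len * num_heads * key_size * bytes_per_value

mha_bytes = kv_cache_bytes(num_q_heads)   # one KV head per Q head
gqa_bytes = kv_cache_bytes(num_kv_heads)  # shared KV heads

assert gqa_bytes == mha_bytes // 2  # ~50% reduction at a 2:1 ratio
```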

2. Feed-Forward Network

Phoenix uses a gated FFN in the SwiGLU family; as implemented, the gate applies GELU rather than SiLU (a GeGLU variant):
class DenseBlock(hk.Module):
    def __call__(self, inputs: jax.Array) -> jax.Array:
        # Gated FFN: element-wise product of a linear projection
        # and a GELU-activated projection
        h_v = Linear(ffn_size(model_size, widening_factor))(inputs)
        h_w1 = jax.nn.gelu(
            Linear(ffn_size(model_size, widening_factor))(inputs)
        )
        h_dense = Linear(model_size)(h_w1 * h_v)
        return h_dense
Source: grok.py:414-440
The FFN size is computed to be a multiple of 8 for hardware efficiency:
def ffn_size(emb_size, widening_factor):
    _ffn_size = int(widening_factor * emb_size) * 2 // 3
    _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # Round up to multiple of 8
    return _ffn_size
Source: grok.py:32-36
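Running the helper for the example configuration shows both effects, the 2/3 scaling and the rounding (the function is repeated here so the snippet runs standalone):

```python
def ffn_size(emb_size, widening_factor):
    # 2/3 factor keeps the gated FFN's parameter count comparable to a
    # standard two-matrix FFN; then round up to a multiple of 8
    _ffn_size = int(widening_factor * emb_size) * 2 // 3
    _ffn_size = _ffn_size + (8 - _ffn_size) % 8
    return _ffn_size

size = ffn_size(2048, 4.0)  # unrounded value would be 16384 // 3 = 5461

assert size % 8 == 0
assert size >= int(4.0 * 2048) * 2 // 3  # rounding never shrinks the width
```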

3. Layer Normalization

Phoenix uses RMS Normalization (Root Mean Square Layer Normalization):
class RMSNorm(hk.RMSNorm):
    def __call__(self, inputs: jax.Array):
        # Compute RMS
        mean_squared = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True)
        normed_inputs = inputs * jax.lax.rsqrt(mean_squared + eps)
        
        # Scale
        return scale * normed_inputs
Source: grok.py:162-194
RMSNorm is simpler and faster than LayerNorm because it skips mean subtraction and normalizes only by the root mean square. This provides similar normalization benefits at lower computational cost.
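A self-contained NumPy version of the normalization (scale fixed to 1.0 for illustration; the real module learns a per-feature scale):

```python
import numpy as np

def rms_norm(x, scale=1.0, eps=1e-6):
    # Normalize by the root mean square over the last axis; no mean subtraction
    mean_squared = np.mean(np.square(x), axis=-1, keepdims=True)
    return scale * x / np.sqrt(mean_squared + eps)

x = np.random.default_rng(0).normal(size=(4, 256))
y = rms_norm(x)

# Output rows have (approximately) unit RMS regardless of input scale
rms = np.sqrt(np.mean(np.square(y), axis=-1))
assert np.allclose(rms, 1.0, atol=1e-3)
# Unlike LayerNorm, the mean is not forced to zero
assert not np.allclose(np.mean(y, axis=-1), 0.0)
```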

4. Decoder Layer Structure

Each transformer layer follows this pattern:
class DecoderLayer(hk.Module):
    def __call__(self, inputs: jax.Array, mask: jax.Array, padding_mask) -> DecoderOutput:
        h = inputs
        
        # Pre-norm + Multi-head attention + Post-norm + Residual
        attn_output = MHABlock(...)(layer_norm(h), mask)
        h_attn = layer_norm(attn_output.embeddings)
        h = h + h_attn
        
        # Pre-norm + FFN + Post-norm + Residual
        h_dense = DenseBlock(...)(layer_norm(h))
        h_dense = layer_norm(h_dense)
        h = h + h_dense
        
        return DecoderOutput(embeddings=h)
Source: grok.py:444-497
Phoenix applies layer norm before and after each sub-layer, which differs from standard “Pre-LN” transformers. This “sandwich” normalization improves training stability.

JAX and Haiku Implementation

Phoenix is implemented using:
  • JAX: For automatic differentiation and XLA compilation
  • Haiku: For neural network modules and parameter management

Key Benefits

JAX

  • Automatic differentiation
  • XLA compilation for TPU/GPU
  • Functional transformations (vmap, pmap)
  • Efficient numerical computing

Haiku

  • Clean module abstraction
  • Explicit parameter management
  • JAX-native design
  • Transformer utilities (RMSNorm, etc.)

Example Usage

import haiku as hk
import jax
import jax.numpy as jnp

# Define model as a Haiku function (batch and embeddings are produced by the
# data pipeline; PhoenixModelConfig fields are elided here)
def forward(batch, embeddings):
    config = PhoenixModelConfig(...)
    model = config.make()
    return model(batch, embeddings)

# Transform to pure functions
forward_fn = hk.transform(forward)

# Initialize parameters
rng = jax.random.PRNGKey(42)
params = forward_fn.init(rng, batch, embeddings)

# Run inference
output = forward_fn.apply(params, rng, batch, embeddings)

Performance Optimizations

Phoenix uses bfloat16 for forward pass computation:
fprop_dtype: Any = jnp.bfloat16
This reduces memory usage and increases throughput on modern accelerators (TPUs, A100s) while maintaining numerical stability.
Using fewer K/V heads than Q heads reduces the memory bandwidth bottleneck during inference, leading to faster decoding.
FFN dimensions are rounded to multiples of 8 for optimal hardware utilization on GPUs and TPUs.
JAX compiles the model to XLA, which performs aggressive optimizations like:
  • Operator fusion
  • Memory layout optimization
  • Automatic parallelization

Comparison: Retrieval vs Ranking Masking

Retrieval User Tower

# Standard causal mask
candidate_start_offset = None
User and history tokens use standard causal attention. There are no candidates in this model.

Ranking Model

# Custom isolation mask
candidate_start_offset = 1 + history_len
Candidates isolated from each other but can attend to user + history.
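A small sketch of how the offset relates to the sequence layout (the lengths here are hypothetical):

```python
# Ranking sequence layout: [user] + history + candidates
history_len, num_candidates = 7, 4

candidate_start_offset = 1 + history_len  # the user token sits at position 0
seq_len = candidate_start_offset + num_candidates

candidate_positions = list(range(candidate_start_offset, seq_len))
assert candidate_positions == [8, 9, 10, 11]  # isolated from one another
```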

grok.py

Core transformer implementation
  • Transformer class
  • make_recsys_attn_mask function
  • Attention and FFN layers

recsys_model.py

Phoenix ranking model
  • PhoenixModel class
  • Input embedding construction
  • Multi-action output projection

recsys_retrieval_model.py

Phoenix retrieval model
  • PhoenixRetrievalModel class
  • User and candidate towers
  • Similarity search

Next Steps

Overview

Return to Phoenix system overview

Ranking Model

Learn about the ranking model implementation

Retrieval Model

Explore the two-tower retrieval architecture
