Deep dive into the transformer implementation, attention masking, and candidate isolation mechanism in Phoenix
This page provides detailed technical information about the Phoenix transformer architecture, including the critical attention masking mechanism that enables candidate isolation.
Phoenix uses a standard decoder-only transformer with several key architectural choices:
@dataclass
class TransformerConfig:
    emb_size: int      # Model dimension (e.g., 2048)
    key_size: int      # Attention head dimension (e.g., 128)
    num_q_heads: int   # Number of query heads (e.g., 16)
    num_kv_heads: int  # Number of key/value heads (e.g., 8)
    num_layers: int    # Number of transformer layers (e.g., 24)
    widening_factor: float = 4.0         # FFN expansion factor
    attn_output_multiplier: float = 1.0  # Attention scaling
Source: grok.py:88-100
Grouped Query Attention: Phoenix uses num_q_heads > num_kv_heads, which reduces memory usage and increases inference speed while maintaining model quality.
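The mechanics of grouped-query attention can be sketched in NumPy: each key/value head is shared by a group of query heads, so K and V are repeated along the head axis before the usual attention computation. The dimensions below are illustrative, not Phoenix's actual config.

```python
import numpy as np

# Illustrative sizes: 4 query heads share 2 KV heads (group size 2).
num_q_heads, num_kv_heads, key_size, seq_len = 4, 2, 8, 5
group_size = num_q_heads // num_kv_heads  # query heads per KV head

rng = np.random.default_rng(0)
q = rng.standard_normal((num_q_heads, seq_len, key_size))
k = rng.standard_normal((num_kv_heads, seq_len, key_size))
v = rng.standard_normal((num_kv_heads, seq_len, key_size))

# Each KV head serves `group_size` query heads: repeat along the head axis.
# Only num_kv_heads KV projections are stored, which shrinks the KV cache.
k_rep = np.repeat(k, group_size, axis=0)  # [num_q_heads, seq_len, key_size]
v_rep = np.repeat(v, group_size, axis=0)

# Standard scaled dot-product attention per query head.
logits = np.einsum("hqd,hkd->hqk", q, k_rep) / np.sqrt(key_size)
weights = np.exp(logits - logits.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)
out = np.einsum("hqk,hkd->hqd", weights, v_rep)
print(out.shape)  # (4, 5, 8)
```

The memory saving comes from caching only `num_kv_heads` K/V tensors at inference time while keeping the full complement of query heads.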
In standard transformer inference for sequences, each position can attend to all previous positions (causal masking). For recommendation ranking, this would allow candidates to see each other:
Standard Causal Mask (WRONG for ranking):

         [User] [Hist] [Cand1] [Cand2] [Cand3]
User       ✓
Hist       ✓      ✓
Cand1      ✓      ✓      ✓
Cand2      ✓      ✓      ✓       ✓              ← Problem: Cand2 sees Cand1
Cand3      ✓      ✓      ✓       ✓       ✓      ← Problem: Cand3 sees Cand1 & Cand2
This creates batch composition effects: the score for Cand2 depends on whether Cand1 is in the batch.
Phoenix implements a specialized attention mask using make_recsys_attn_mask:
def make_recsys_attn_mask(
    seq_len: int,
    candidate_start_offset: int,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Create attention mask for recommendation system inference.

    Creates a mask where:
    - Positions 0 to candidate_start_offset-1 (user+history): causal attention
    - Positions candidate_start_offset onwards (candidates): can attend to
      user+history and themselves (self-attention), but NOT to other candidates

    This ensures each candidate is scored independently based on user+history
    context.

    Args:
      seq_len: Total sequence length (user + history + candidates)
      candidate_start_offset: Position where candidates start in the sequence
      dtype: Data type for the mask

    Returns:
      Attention mask of shape [1, 1, seq_len, seq_len] where 1 means "can attend"
    """
    # Start with causal mask for the full sequence
    causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=dtype))

    # Zero out candidate-to-candidate attention (bottom-right block)
    attn_mask = causal_mask.at[:, :, candidate_start_offset:, candidate_start_offset:].set(0)

    # Add back self-attention for candidates (diagonal of the candidate block)
    candidate_indices = jnp.arange(candidate_start_offset, seq_len)
    attn_mask = attn_mask.at[:, :, candidate_indices, candidate_indices].set(1)

    return attn_mask
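The effect of this mask is easiest to see on a tiny example. The sketch below mirrors the same three steps (causal mask, zeroed candidate block, restored diagonal) in plain NumPy, for three context positions followed by two candidates:

```python
import numpy as np

def make_recsys_attn_mask_np(seq_len: int, candidate_start_offset: int) -> np.ndarray:
    """NumPy mirror of the JAX mask construction, for illustration."""
    # Step 1: causal mask over the full sequence.
    mask = np.tril(np.ones((seq_len, seq_len)))
    # Step 2: zero out the candidate-to-candidate block.
    mask[candidate_start_offset:, candidate_start_offset:] = 0
    # Step 3: restore self-attention on the candidate diagonal.
    idx = np.arange(candidate_start_offset, seq_len)
    mask[idx, idx] = 1
    return mask

# 3 context positions (user + history), 2 candidates.
m = make_recsys_attn_mask_np(seq_len=5, candidate_start_offset=3)
print(m)
```

Each candidate row ends up attending to all of the context plus its own position only, so the second candidate's score cannot depend on the first candidate's presence.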
The transformer applies the custom mask during attention computation:
class Transformer(hk.Module):
    def __call__(
        self,
        embeddings: jax.Array,  # [B, T, D]
        mask: jax.Array,        # [B, T]
        candidate_start_offset: Optional[int] = None,
    ) -> TransformerOutput:
        """Transforms input embedding sequences to output embedding sequences.

        Args:
          embeddings: Input embeddings of shape [B, T, D]
          mask: Padding mask of shape [B, T], True for valid positions
          candidate_start_offset: If provided, positions >= this offset are
            treated as candidates that can only attend to positions before the
            offset (user+history) and themselves (self-attention), but not to
            other candidates. Used for recommendation system inference.

        Returns:
          TransformerOutput containing the output embeddings.
        """
        fprop_dtype = embeddings.dtype
        _, seq_len, _ = embeddings.shape

        padding_mask = mask.copy()
        mask = mask[:, None, None, :]  # [B, H=1, T'=1, T]

        if candidate_start_offset is not None:
            # Use recommendation system attention mask where candidates attend to
            # user+history and themselves, but not to other candidates
            attn_mask = make_recsys_attn_mask(seq_len, candidate_start_offset, fprop_dtype)
            mask = mask * attn_mask
        else:
            # Standard causal mask for autoregressive sequence modelling
            causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(fprop_dtype)
            mask = mask * causal_mask

        h = embeddings
        # Apply transformer layers
        for i in range(self.num_layers):
            decoder_output = DecoderLayer(...)(h, mask, padding_mask)
            h = decoder_output.embeddings

        return TransformerOutput(embeddings=h)
Source: grok.py:516-586
The candidate_start_offset parameter controls the masking behavior:
None: Standard causal mask (for language modeling or the retrieval user tower)
An integer offset: Recommendation-system mask in which each candidate attends to user+history and itself, but not to other candidates (for ranking)
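The `mask = mask * attn_mask` step above combines two masks by broadcasting: the `[B, 1, 1, T]` padding mask zeroes out padded key positions for every query, while the `[1, 1, T, T]` attention mask removes candidate-to-candidate links. A small NumPy sketch of that combination (shapes are illustrative):

```python
import numpy as np

# Illustrative shapes: batch of 2, seq_len 5, candidates start at position 3.
seq_len, offset = 5, 3
padding = np.array([[1, 1, 1, 1, 1],
                    [1, 1, 1, 1, 0]], dtype=np.float32)  # row 2: last slot is padding

# Recommendation-system attention mask, built as in make_recsys_attn_mask.
attn = np.tril(np.ones((1, 1, seq_len, seq_len), dtype=np.float32))
attn[:, :, offset:, offset:] = 0
idx = np.arange(offset, seq_len)
attn[:, :, idx, idx] = 1

# [B,1,1,T] * [1,1,T,T] broadcasts to [B,1,T,T]: padding removes keys globally,
# the attention mask removes candidate-to-candidate attention.
combined = padding[:, None, None, :] * attn
print(combined.shape)  # (2, 1, 5, 5)
```

A padded candidate slot thus loses even its self-attention entry, while valid candidates keep exactly the context-plus-self pattern.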
RMSNorm is simpler and faster than LayerNorm because it doesn’t compute the mean (only the RMS). This provides similar normalization benefits with lower computational cost.
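The difference is a couple of lines when written out. A minimal NumPy sketch of both normalizations (function names and the scalar `scale` are illustrative, not Phoenix's API):

```python
import numpy as np

def rms_norm(x: np.ndarray, scale=1.0, eps: float = 1e-6) -> np.ndarray:
    # RMSNorm: divide by the root mean square only; no mean subtraction.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * scale

def layer_norm_ref(x: np.ndarray, scale=1.0, eps: float = 1e-6) -> np.ndarray:
    # LayerNorm, for comparison: centers by the mean before scaling,
    # which costs an extra pass over the feature dimension.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * scale

x = np.array([3.0, 4.0])
y = rms_norm(x)  # normalized so that the RMS of y is ~1
```

Skipping the mean computation (and the learned bias that usually accompanies it) is what makes RMSNorm cheaper per token.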
class DecoderLayer(hk.Module):
    def __call__(self, inputs: jax.Array, mask: jax.Array, padding_mask) -> DecoderOutput:
        h = inputs

        # Pre-norm + Multi-head attention + Post-norm + Residual
        attn_output = MHABlock(...)(layer_norm(h), mask)
        h_attn = layer_norm(attn_output.embeddings)
        h = h + h_attn

        # Pre-norm + FFN + Post-norm + Residual
        h_dense = DenseBlock(...)(layer_norm(h))
        h_dense = layer_norm(h_dense)
        h = h + h_dense

        return DecoderOutput(embeddings=h)
Source: grok.py:444-497
Phoenix applies layer norm before and after each sub-layer, which differs from standard “Pre-LN” transformers. This “sandwich” normalization improves training stability.
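The stability benefit can be seen in a toy comparison: with sandwich normalization, the post-norm bounds the magnitude of whatever the sub-layer adds to the residual stream, no matter how large the sub-layer's output grows. A sketch under simplified assumptions (an RMS-style norm without learned scale, and a sub-layer replaced by a plain function):

```python
import numpy as np

def rms(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def pre_ln_block(x, sublayer):
    # Standard Pre-LN: normalize only the sub-layer input.
    return x + sublayer(rms(x))

def sandwich_block(x, sublayer):
    # Sandwich norm: normalize the sub-layer input AND its output,
    # bounding the magnitude added to the residual stream.
    return x + rms(sublayer(rms(x)))

x = np.array([[3.0, 4.0]])
exploding = lambda h: 100.0 * h  # a sub-layer with a large gain

branch_pre = pre_ln_block(x, exploding) - x       # RMS ~100: can destabilize
branch_sand = sandwich_block(x, exploding) - x    # RMS ~1: bounded contribution
```

Regardless of the sub-layer's gain, the sandwich branch contributes a unit-RMS update, which is one intuition for why this scheme trains more stably at large scale.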