Candidate isolation is a critical design pattern in the Phoenix ranking transformer: it ensures that candidates cannot attend to each other during inference. This architectural choice guarantees that a candidate's score depends only on the user and history context, not on which other candidates happen to be in the same batch.
In a standard transformer, every position can attend to every other position (subject to causal masking). For recommendation systems, this creates several problems:

- If candidate A can attend to candidate B, then A's score depends on B being in the batch
- Scores become inconsistent across different batches
- Caching becomes impossible, since scores change based on batch composition
- The model learns spurious correlations between co-occurring items rather than true user preferences
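The first failure mode can be demonstrated with a toy single-head attention in NumPy (a sketch for illustration, not the Phoenix model): under a plain causal mask, the last candidate in the sequence attends to whichever candidate precedes it, so its representation changes with batch composition.

```python
import numpy as np

def attend(x, mask):
    # Toy single-head self-attention with a 0/1 attention mask
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores = np.where(mask > 0, scores, -1e9)   # mask out forbidden edges
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ x

rng = np.random.default_rng(0)
ctx = rng.normal(size=(3, 4))          # user + history embeddings (shared)
a, b, c = rng.normal(size=(3, 1, 4))   # three candidate embeddings

causal = np.tril(np.ones((5, 5)))      # plain causal mask, no isolation
# Candidate `a` sits last; only the *other* candidate in the batch changes
out_ab = attend(np.concatenate([ctx, b, a]), causal)
out_ac = attend(np.concatenate([ctx, c, a]), causal)
# a's representation (last position) now differs between the two batches,
# while the context rows are unaffected
```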
Phoenix solves this by implementing a specialized attention mask in the `make_recsys_attn_mask` function that enforces candidate isolation:
phoenix/grok.py
```python
def make_recsys_attn_mask(
    seq_len: int,
    candidate_start_offset: int,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Create attention mask for recommendation system inference.

    Creates a mask where:
    - Positions 0 to candidate_start_offset-1 (user+history): causal attention
    - Positions candidate_start_offset onwards (candidates): can attend to
      user+history and themselves (self-attention), but NOT to other candidates

    This ensures each candidate is scored independently based on user+history
    context.
    """
    # Start with causal mask for the full sequence
    causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=dtype))
    # Zero out candidate-to-candidate attention (bottom-right block)
    attn_mask = causal_mask.at[
        :, :, candidate_start_offset:, candidate_start_offset:
    ].set(0)
    # Add back self-attention for candidates (diagonal of the candidate block)
    candidate_indices = jnp.arange(candidate_start_offset, seq_len)
    attn_mask = attn_mask.at[:, :, candidate_indices, candidate_indices].set(1)
    return attn_mask
```
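To see the resulting mask structure concretely, here is a small NumPy sketch of the same logic (the real implementation uses JAX and carries batch/head dimensions), with a 5-token sequence and candidates starting at offset 3:

```python
import numpy as np

def recsys_mask(seq_len: int, offset: int) -> np.ndarray:
    # NumPy sketch of the masking logic: causal base, candidate block
    # zeroed, candidate self-attention restored on the diagonal
    mask = np.tril(np.ones((seq_len, seq_len)))
    mask[offset:, offset:] = 0
    np.fill_diagonal(mask[offset:, offset:], 1)  # writes through the view
    return mask

mask = recsys_mask(5, 3)
# Rows are queries, columns are keys:
# [[1 0 0 0 0]    <- history: plain causal
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 0]    <- candidate at 3: history + itself only
#  [1 1 1 0 1]]   <- candidate at 4: history + itself only
```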
The candidate_start_offset parameter marks the boundary between history and candidates:
phoenix/grok.py
```python
# In Transformer.__call__
if candidate_start_offset is not None:
    # Use recommendation system attention mask where candidates attend to
    # user+history and themselves, but not to other candidates
    attn_mask = make_recsys_attn_mask(seq_len, candidate_start_offset, fprop_dtype)
    mask = mask * attn_mask
else:
    # Standard causal mask for autoregressive sequence modelling
    causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(fprop_dtype)
    mask = mask * causal_mask
```
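Note that `mask = mask * attn_mask` composes the recsys mask with whatever mask is already in place (for example a padding mask): elementwise multiplication keeps an attention edge only where both masks allow it. A NumPy sketch of that composition (the padding setup here is hypothetical):

```python
import numpy as np

def recsys_mask(seq_len, offset):
    # Same masking logic as above, sketched in NumPy
    mask = np.tril(np.ones((seq_len, seq_len)))
    mask[offset:, offset:] = 0
    np.fill_diagonal(mask[offset:, offset:], 1)
    return mask

seq_len, offset = 6, 4
# Hypothetical padding mask: the last position is padding, so no query
# may attend to it; broadcast over all query positions (rows)
pad = np.ones(seq_len)
pad[-1] = 0
combined = pad[None, :] * recsys_mask(seq_len, offset)
# An edge survives only where BOTH masks allow it: even the padded
# candidate's self-attention is zeroed out
```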
Each candidate receives the same score regardless of which other candidates are in the batch:
```python
# Score for candidate A is the same in both batches
batch_1 = [user, history, A, B, C]
batch_2 = [user, history, A, D, E]
# score(A | user, history) is identical
```
The model can score candidates in parallel without dependencies:
```python
# Each candidate is scored independently
for candidate in candidates:
    score[candidate] = model(user, history, candidate)
# No cross-candidate information needed
```
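This batch-independence can be checked end to end with a toy single-head attention in NumPy (a sketch, not the Phoenix model): under the isolation mask, candidate A's representation is identical (up to floating point) whether it shares the batch with B and C or with D and E.

```python
import numpy as np

def recsys_mask(seq_len, offset):
    # NumPy sketch of the isolation mask described above
    mask = np.tril(np.ones((seq_len, seq_len)))
    mask[offset:, offset:] = 0
    np.fill_diagonal(mask[offset:, offset:], 1)
    return mask

def attend(x, mask):
    # Toy single-head self-attention with a 0/1 attention mask
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores = np.where(mask > 0, scores, -1e9)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ x

rng = np.random.default_rng(0)
ctx = rng.normal(size=(3, 4))        # user + history embeddings (shared)
cand_a = rng.normal(size=(1, 4))     # candidate A (shared)
others_1 = rng.normal(size=(2, 4))   # candidates B, C
others_2 = rng.normal(size=(2, 4))   # candidates D, E

mask = recsys_mask(6, 3)             # candidates start at position 3
out_1 = attend(np.concatenate([ctx, cand_a, others_1]), mask)
out_2 = attend(np.concatenate([ctx, cand_a, others_2]), mask)
# Candidate A's representation (position 3) is identical in both batches;
# the other candidates' representations differ because the tokens differ
```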
The implementation includes comprehensive tests to verify the attention mask structure:
phoenix/test_recsys_model.py
```python
def test_candidates_do_not_attend_to_other_candidates(self):
    """Test that candidates cannot attend to other candidates."""
    seq_len = 8
    candidate_start_offset = 5
    mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
    mask_2d = mask[0, 0]
    for query_pos in range(candidate_start_offset, seq_len):
        for key_pos in range(candidate_start_offset, seq_len):
            if query_pos != key_pos:
                assert mask_2d[query_pos, key_pos] == 0, (
                    f"Candidate at {query_pos} should NOT attend to "
                    f"candidate at {key_pos}"
                )
```
See the full test suite in /workspace/source/phoenix/test_recsys_model.py for all edge cases.