Overview

Candidate isolation is a critical design pattern in the Phoenix ranking transformer: candidates are prevented from attending to each other during inference. This architectural choice guarantees that a candidate's score depends only on the user and history context, never on which other candidates happen to be in the same batch.

The Problem

In a standard transformer, every position can attend to every other position (subject to causal masking). For recommendation systems, this creates a problem:
  • If candidate A can attend to candidate B, then A’s score depends on B being in the batch
  • Scores become inconsistent across different batches
  • Caching becomes impossible since scores change based on batch composition
  • The model learns spurious correlations between co-occurring items rather than true user preferences
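The leak is visible in the mask itself. A minimal sketch (positions are illustrative, not Phoenix's actual layout): with a plain causal mask, a later candidate position has full attention to an earlier one.

```python
# With a plain causal mask, a later candidate CAN attend to an earlier one,
# so its attention output depends on which candidates share the batch.
import jax.numpy as jnp

seq_len = 6              # e.g. [user, h1, h2, h3, cand_A, cand_B]
causal = jnp.tril(jnp.ones((seq_len, seq_len)))

cand_a, cand_b = 4, 5    # positions of the two candidates
print(causal[cand_b, cand_a])  # 1.0 -> candidate B attends to candidate A
```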

The Solution: Custom Attention Masking

Phoenix solves this by implementing a specialized attention mask in the make_recsys_attn_mask function that enforces candidate isolation:
phoenix/grok.py
def make_recsys_attn_mask(
    seq_len: int,
    candidate_start_offset: int,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Create attention mask for recommendation system inference.

    Creates a mask where:
    - Positions 0 to candidate_start_offset-1 (user+history): causal attention
    - Positions candidate_start_offset onwards (candidates): can attend to user+history
      and themselves (self-attention), but NOT to other candidates

    This ensures each candidate is scored independently based on user+history context.
    """
    # Start with causal mask for the full sequence
    causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=dtype))

    # Zero out candidate-to-candidate attention (bottom-right block)
    attn_mask = causal_mask.at[:, :, candidate_start_offset:, candidate_start_offset:].set(0)

    # Add back self-attention for candidates (diagonal of the candidate block)
    candidate_indices = jnp.arange(candidate_start_offset, seq_len)
    attn_mask = attn_mask.at[:, :, candidate_indices, candidate_indices].set(1)

    return attn_mask
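
A small worked example makes the resulting structure concrete (the function body is reproduced here so the snippet runs standalone):

```python
import jax.numpy as jnp

def make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=jnp.float32):
    # Same logic as phoenix/grok.py, reproduced for a standalone run.
    causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=dtype))
    attn_mask = causal_mask.at[
        :, :, candidate_start_offset:, candidate_start_offset:
    ].set(0)
    idx = jnp.arange(candidate_start_offset, seq_len)
    return attn_mask.at[:, :, idx, idx].set(1)

mask = make_recsys_attn_mask(seq_len=6, candidate_start_offset=3)[0, 0]
# Rows 3-5 (candidates) attend to columns 0-2 (user + history) and to
# themselves; the candidate block is the identity.
print(mask)
```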

Attention Mask Structure

The mask creates three distinct regions:
                  ATTENTION MASK VISUALIZATION

        Keys (what we attend TO)
        ─────────────────────────────────────────────▶

        │ User │    History (S positions)    │   Candidates (C positions)    │
   ┌────┼──────┼─────────────────────────────┼───────────────────────────────┤
   │    │      │                             │                               │
   │ U  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✗   ✗    │
   │    │      │                             │                               │
   ├────┼──────┼─────────────────────────────┼───────────────────────────────┤
Q  │    │      │                             │                               │
u  │ H  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✗   ✗    │
e  │ i  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✗   ✗    │
r  │ s  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✗   ✗    │
i  │ t  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✗   ✗    │
e  │    │      │                             │                               │
s  ├────┼──────┼─────────────────────────────┼───────────────────────────────┤
   │    │      │                             │  DIAGONAL ONLY (self-attend)  │
│  │ C  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✓   ✗   ✗   ✗   ✗   ✗   ✗    │
│  │ a  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✓   ✗   ✗   ✗   ✗   ✗    │
│  │ n  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✓   ✗   ✗   ✗   ✗    │
│  │ d  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✓   ✗   ✗   ✗    │
│  │ i  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✓   ✗   ✗    │
│  │ d  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✓   ✗    │
▼  │ s  │  ✓   │  ✓   ✓   ✓   ✓   ✓   ✓   ✓  │  ✗   ✗   ✗   ✗   ✗   ✗   ✓    │
   │    │      │                             │                               │
   └────┴──────┴─────────────────────────────┴───────────────────────────────┘

   ✓ = Can attend (1)          ✗ = Cannot attend (0)
Key Insight: the bottom-right block is diagonal-only. Within the candidate block, each candidate attends only to itself (self-attention), never to another candidate.

Sequence Structure

The input sequence to the transformer is structured as:
[User] [History Item 1] [History Item 2] ... [History Item N] [Candidate 1] [Candidate 2] ... [Candidate M]
The candidate_start_offset parameter marks the boundary between history and candidates:
phoenix/grok.py
# In Transformer.__call__
if candidate_start_offset is not None:
    # Use recommendation system attention mask where candidates attend to
    # user+history and themselves, but not to other candidates
    attn_mask = make_recsys_attn_mask(seq_len, candidate_start_offset, fprop_dtype)
    mask = mask * attn_mask
else:
    # Standard causal mask for autoregressive sequence modelling
    causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(fprop_dtype)
    mask = mask * causal_mask
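
The multiplication with `mask` combines the attention mask with the per-example padding mask via broadcasting. A hedged sketch, assuming a padding mask of shape `[batch, 1, 1, seq_len]` (the exact shape in Phoenix may differ):

```python
# Sketch: a per-example padding mask of shape [batch, 1, 1, seq_len]
# broadcasts against the shared [1, 1, seq_len, seq_len] attention mask,
# zeroing attention to padded key positions.
import jax.numpy as jnp

batch, seq_len = 2, 4
attn_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len)))  # shared causal mask
# Example 0 has one padded position at the end; example 1 has none.
padding = jnp.array([[1, 1, 1, 0],
                     [1, 1, 1, 1]], dtype=jnp.float32)
padding_mask = padding[:, None, None, :]                  # [batch, 1, 1, seq]

combined = padding_mask * attn_mask                       # [batch, 1, seq, seq]
print(combined[0, 0, 3])  # last row of example 0: [1. 1. 1. 0.]
```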

Benefits

1. Consistent Scoring

Each candidate receives the same score regardless of which other candidates are in the batch:
# Score for candidate A is the same in both batches
batch_1 = [user, history, A, B, C]
batch_2 = [user, history, A, D, E]
# score(A | user, history) is identical
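
This invariance can be checked numerically at the attention level. The following is a sketch using toy single-head attention (q = k = v, not the Phoenix model): with the isolation mask, a candidate's output is unchanged when the other candidates are swapped.

```python
import jax
import jax.numpy as jnp

def attend(x, mask):
    # Toy single-head attention: q = k = v = x, large-negative masking.
    scores = x @ x.T / jnp.sqrt(x.shape[-1])
    scores = jnp.where(mask > 0, scores, -1e9)
    return jax.nn.softmax(scores, axis=-1) @ x

seq_len, offset, dim = 6, 3, 8     # [user, h1, h2, A, B, C]
x1 = jax.random.normal(jax.random.PRNGKey(0), (seq_len, dim))
# Same user/history and candidate A, but candidates B and C replaced.
x2 = x1.at[offset + 1:].set(
    jax.random.normal(jax.random.PRNGKey(1), (seq_len - offset - 1, dim))
)

# Isolation mask: causal, with candidate-to-candidate attention removed.
mask = jnp.tril(jnp.ones((seq_len, seq_len)))
mask = mask.at[offset:, offset:].set(0)
idx = jnp.arange(offset, seq_len)
mask = mask.at[idx, idx].set(1)

out1, out2 = attend(x1, mask), attend(x2, mask)
# Candidate A (position 3) is unaffected by replacing candidates B and C.
print(jnp.allclose(out1[offset], out2[offset]))  # True
```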

2. Cacheability

Since scores are stable, you can:
  • Pre-compute scores for candidates offline
  • Cache scores for frequently shown candidates
  • Incrementally score new candidates without rescoring everything

3. Scalability

The model can score all candidates in one forward pass with no cross-candidate dependencies, which is conceptually equivalent to:
# Each candidate is scored independently
for candidate in candidates:
    score[candidate] = model(user, history, candidate)
# No cross-candidate information needed

4. Interpretability

Scores reflect true user-candidate relevance, not spurious batch effects:
  • Easier to debug model behavior
  • A/B tests are more reliable
  • Score distributions are stable over time

Implementation in Phoenix

The Phoenix ranking model applies candidate isolation automatically when candidate_start_offset is provided:
phoenix/recsys_model.py
def __call__(
    self,
    batch: RecsysBatch,
    recsys_embeddings: RecsysEmbeddings,
) -> RecsysModelOutput:
    """Forward pass for ranking candidates."""
    embeddings, padding_mask, candidate_start_offset = self.build_inputs(
        batch, recsys_embeddings
    )

    # Transformer with candidate isolation
    model_output = self.model(
        embeddings,
        padding_mask,
        candidate_start_offset=candidate_start_offset,  # Enables isolation
    )

    # Extract only candidate outputs for scoring
    out_embeddings = model_output.embeddings
    candidate_embeddings = out_embeddings[:, candidate_start_offset:, :]

    # Project to action logits
    unembeddings = self._get_unembedding()
    logits = jnp.dot(candidate_embeddings, unembeddings)

    return RecsysModelOutput(logits=logits)

Testing

The implementation includes comprehensive tests to verify the attention mask structure:
phoenix/test_recsys_model.py
def test_candidates_do_not_attend_to_other_candidates(self):
    """Test that candidates cannot attend to other candidates."""
    seq_len = 8
    candidate_start_offset = 5

    mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
    mask_2d = mask[0, 0]

    for query_pos in range(candidate_start_offset, seq_len):
        for key_pos in range(candidate_start_offset, seq_len):
            if query_pos != key_pos:
                assert mask_2d[query_pos, key_pos] == 0, (
                    f"Candidate at {query_pos} should NOT attend to candidate at {key_pos}"
                )
See the full test suite in phoenix/test_recsys_model.py for all edge cases.
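
A complementary check, sketched here rather than quoted from the suite: candidates should still attend to every user/history position and to themselves (the mask function is reproduced so the snippet runs standalone).

```python
import jax.numpy as jnp

def make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=jnp.float32):
    # Same logic as phoenix/grok.py, reproduced for a standalone run.
    causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=dtype))
    attn_mask = causal_mask.at[
        :, :, candidate_start_offset:, candidate_start_offset:
    ].set(0)
    idx = jnp.arange(candidate_start_offset, seq_len)
    return attn_mask.at[:, :, idx, idx].set(1)

seq_len, offset = 8, 5
mask_2d = make_recsys_attn_mask(seq_len, offset)[0, 0]

for q in range(offset, seq_len):
    assert mask_2d[q, q] == 1            # self-attention preserved
    for k in range(offset):
        assert mask_2d[q, k] == 1        # full access to user + history
print("ok")
```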

