The Phoenix ranking model is a transformer-based neural network that scores candidates retrieved by the retrieval stage. It predicts multiple engagement actions (likes, reposts, replies, etc.) for each candidate while ensuring that scores are independent of other items in the batch.

Model Architecture

                          PHOENIX RANKING MODEL
┌────────────────────────────────────────────────────────────────────────────┐
│                                                                            │
│                              OUTPUT LOGITS                                 │
│                        [B, num_candidates, num_actions]                    │
│                                    │                                       │
│                                    │ Unembedding                           │
│                                    │ Projection                            │
│                                    │                                       │
│                    ┌───────────────┴───────────────┐                       │
│                    │                               │                       │
│                    │    Extract Candidate Outputs  │                       │
│                    │    (positions after history)  │                       │
│                    │                               │                       │
│                    └───────────────┬───────────────┘                       │
│                                    │                                       │
│                    ┌───────────────┴───────────────┐                       │
│                    │                               │                       │
│                    │         Transformer           │                       │
│                    │     (with special masking)    │                       │
│                    │                               │                       │
│                    │   Candidates CANNOT attend    │                       │
│                    │   to each other               │                       │
│                    │                               │                       │
│                    └───────────────┬───────────────┘                       │
│                                    │                                       │
│    ┌───────────────────────────────┼───────────────────────────────┐       │
│    │                               │                               │       │
│    ▼                               ▼                               ▼       │
│ ┌──────────┐              ┌─────────────────┐              ┌────────────┐  │
│ │   User   │              │     History     │              │ Candidates │  │
│ │Embedding │              │   Embeddings    │              │ Embeddings │  │
│ │ [B, 1, D]│              │    [B, S, D]    │              │  [B, C, D] │  │
│ │          │              │                 │              │            │  │
│ │ User     │              │ Posts + Authors │              │ Posts +    │  │
│ │ Hashes   │              │ + Actions +     │              │ Authors +  │  │
│ │          │              │ Product Surface │              │ Product    │  │
│ └──────────┘              └─────────────────┘              │ Surface    │  │
│                                                            └────────────┘  │
│                                                                            │
└────────────────────────────────────────────────────────────────────────────┘

Key Design: Candidate Isolation

Critical Design Requirement: The score for a candidate must not depend on which other candidates are in the batch. This ensures consistent scoring regardless of batch composition.
Candidate isolation is achieved through custom attention masking that prevents candidates from attending to each other while still allowing them to attend to user and history context.
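The mask construction itself lives inside the transformer and is not shown in the source; as a minimal sketch (function name hypothetical, assuming causal attention over the user/history prefix), the kind of mask that `candidate_start_offset` enables could look like:

```python
import jax.numpy as jnp

def candidate_isolation_mask(seq_len: int, candidate_start_offset: int) -> jnp.ndarray:
    """Build a [seq_len, seq_len] boolean mask (True = query may attend to key).

    User/history positions (before `candidate_start_offset`) use ordinary
    causal attention; candidate positions may attend to the full context
    prefix and to themselves, but never to other candidates.
    """
    pos = jnp.arange(seq_len)
    is_candidate = pos >= candidate_start_offset

    causal = pos[None, :] <= pos[:, None]            # key index <= query index
    attends_context = ~is_candidate[None, :]         # key is user/history
    attends_self = pos[:, None] == pos[None, :]

    candidate_rows = attends_context | attends_self
    return jnp.where(is_candidate[:, None], candidate_rows, causal)
```

Because each candidate row of the mask depends only on the context prefix and its own position, adding or removing other candidates cannot change its attention pattern, which is exactly the batch-independence property described above.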

Why Candidate Isolation Matters

Without isolation, the model could:
  • Score candidates differently based on which other items are in the batch
  • Create inconsistent rankings across different batch compositions
  • Make A/B testing and experimentation unreliable
  • Cause unexpected behavior when batch sizes or sampling strategies change
With isolation, each candidate receives a deterministic score based solely on the user and history context, independent of the other candidates in the batch.

Forward Pass Implementation

Here’s how the ranking model processes inputs:
def __call__(
    self,
    batch: RecsysBatch,
    recsys_embeddings: RecsysEmbeddings,
) -> RecsysModelOutput:
    """Forward pass for ranking candidates.

    Args:
        batch: RecsysBatch containing hashes, actions, product surfaces
        recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings

    Returns:
        RecsysModelOutput containing logits for each candidate. 
        Shape = [B, num_candidates, num_actions]
    """
    # Build input embeddings: [B, 1 + S + C, D]
    embeddings, padding_mask, candidate_start_offset = self.build_inputs(
        batch, recsys_embeddings
    )

    # Pass through transformer with candidate isolation
    model_output = self.model(
        embeddings,
        padding_mask,
        candidate_start_offset=candidate_start_offset,  # Enables special masking
    )

    out_embeddings = model_output.embeddings
    out_embeddings = layer_norm(out_embeddings)

    # Extract only candidate outputs (discard user + history)
    candidate_embeddings = out_embeddings[:, candidate_start_offset:, :]

    # Project to action logits
    unembeddings = self._get_unembedding()  # [D, num_actions]
    logits = jnp.dot(candidate_embeddings, unembeddings)  # [B, C, num_actions]

    return RecsysModelOutput(logits=logits)
Source: recsys_model.py:439-474

Input Embedding Construction

The model combines multiple feature types into input embeddings:

1. User Embedding

def block_user_reduce(
    user_hashes: jnp.ndarray,  # [B, num_user_hashes]
    user_embeddings: jnp.ndarray,  # [B, num_user_hashes, D]
    num_user_hashes: int,
    emb_size: int,
) -> Tuple[jax.Array, jax.Array]:
    """Combine multiple user hash embeddings into a single user representation.

    Returns:
        user_embedding: [B, 1, D] - combined user embedding
        user_padding_mask: [B, 1] - True where user is valid
    """
    B, D = user_embeddings.shape[0], emb_size

    # Flatten multiple hash embeddings
    user_embedding = user_embeddings.reshape((B, 1, num_user_hashes * D))

    # Project to embedding dimension
    user_embedding = jnp.dot(user_embedding, proj_mat_1)  # [B, 1, D]

    # Create padding mask (hash 0 is reserved for padding)
    user_padding_mask = (user_hashes[:, 0] != 0).reshape(B, 1)

    return user_embedding, user_padding_mask
Source: recsys_model.py:79-119

2. History Embeddings

History embeddings combine:
  • Post hashes (multiple per position)
  • Author hashes (multiple per position)
  • Action embeddings (multi-hot vector of engagement types)
  • Product surface (e.g., Home, Search, Profile)
def block_history_reduce(
    history_post_hashes: jnp.ndarray,  # [B, S, num_item_hashes]
    history_post_embeddings: jnp.ndarray,  # [B, S, num_item_hashes, D]
    history_author_embeddings: jnp.ndarray,  # [B, S, num_author_hashes, D]
    history_product_surface_embeddings: jnp.ndarray,  # [B, S, D]
    history_actions_embeddings: jnp.ndarray,  # [B, S, D]
    ...
) -> Tuple[jax.Array, jax.Array]:
    """Combine history embeddings into sequence.

    Returns:
        history_embeddings: [B, S, D]
        history_padding_mask: [B, S]
    """
    # Concatenate all features
    post_author_embedding = jnp.concatenate(
        [
            history_post_embeddings_reshaped,
            history_author_embeddings_reshaped,
            history_actions_embeddings,
            history_product_surface_embeddings,
        ],
        axis=-1,
    )

    # Project to embedding dimension
    history_embedding = jnp.dot(post_author_embedding, proj_mat_3)

    return history_embedding, history_padding_mask
Source: recsys_model.py:122-182

3. Candidate Embeddings

Similar to history, but without action embeddings (since we’re predicting actions):
def block_candidate_reduce(
    candidate_post_hashes: jnp.ndarray,  # [B, C, num_item_hashes]
    candidate_post_embeddings: jnp.ndarray,  # [B, C, num_item_hashes, D]
    candidate_author_embeddings: jnp.ndarray,  # [B, C, num_author_hashes, D]
    candidate_product_surface_embeddings: jnp.ndarray,  # [B, C, D]
    ...
) -> Tuple[jax.Array, jax.Array]:
    """Combine candidate embeddings into sequence.

    Returns:
        candidate_embeddings: [B, C, D]
        candidate_padding_mask: [B, C]
    """
    post_author_embedding = jnp.concatenate(
        [
            candidate_post_embeddings_reshaped,
            candidate_author_embeddings_reshaped,
            candidate_product_surface_embeddings,
        ],
        axis=-1,
    )

    candidate_embedding = jnp.dot(post_author_embedding, proj_mat_2)

    return candidate_embedding, candidate_padding_mask
Source: recsys_model.py:185-242

Multi-Action Prediction

The model predicts multiple engagement types simultaneously:
def _get_unembedding(self) -> jax.Array:
    """Get the unembedding matrix for decoding to logits."""
    unembed_mat = hk.get_parameter(
        "unembeddings",
        [config.emb_size, config.num_actions],
        dtype=jnp.float32,
        init=embed_init,
    )
    return unembed_mat

# Usage:
logits = jnp.dot(candidate_embeddings, unembeddings)
# Shape: [B, num_candidates, num_actions]
Each action dimension corresponds to a different engagement type:
  • Like probability
  • Repost probability
  • Reply probability
  • Click probability
  • Negative feedback probability
  • etc.
The final ranking score is typically a weighted combination of these action logits, where weights reflect business objectives (e.g., replies might be weighted higher than likes).
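The actual weights and action ordering are product decisions not shown in the source; as a hedged illustration, the weighted combination might be computed like this:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-action weights; real values reflect business objectives
# (e.g. replies weighted higher than likes, negative feedback penalized).
ACTION_WEIGHTS = jnp.array([
    1.0,    # like
    2.0,    # repost
    4.0,    # reply
    0.5,    # click
    -10.0,  # negative feedback
])

def ranking_score(logits: jnp.ndarray) -> jnp.ndarray:
    """Reduce per-action logits [B, C, num_actions] to one score per candidate [B, C].

    Sigmoid rather than softmax, since the actions are independent
    binary events rather than mutually exclusive classes.
    """
    probs = jax.nn.sigmoid(logits)
    return jnp.einsum("bca,a->bc", probs, ACTION_WEIGHTS)
```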

Hash-Based Feature Encoding

Phoenix uses multiple hash functions for embedding lookup. Hashing each feature several times reduces collisions and yields richer representations:
  • Each feature (user, post, author) is hashed multiple times
  • Each hash produces a separate embedding lookup
  • Embeddings are concatenated and projected to the final dimension
This approach works well for sparse, high-cardinality features common in recommendation systems.
@dataclass
class HashConfig:
    num_user_hashes: int = 2
    num_item_hashes: int = 2
    num_author_hashes: int = 2
Source: recsys_model.py:32-38
Hash value 0 is reserved for padding. The padding mask is computed by checking if the first hash is non-zero:
padding_mask = (hashes[:, :, 0] != 0)
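The lookup pattern described above can be sketched as follows. The hash functions here are placeholders (the real ones are implementation details not shown in the source); the sketch only demonstrates the lookup-concatenate pattern and the hash-0 padding convention:

```python
import jax.numpy as jnp

def multi_hash_embed(ids: jnp.ndarray, table: jnp.ndarray, num_hashes: int) -> jnp.ndarray:
    """Look up several salted-hash embeddings per id and concatenate them.

    `ids` has any shape; `table` is [table_size, D]. Hash value 0 is
    reserved for padding, so valid hashes land in [1, table_size - 1].
    """
    table_size, dim = table.shape
    salted = [
        (ids * (2 * k + 1) + k) % (table_size - 1) + 1  # placeholder hash
        for k in range(num_hashes)
    ]
    hashes = jnp.stack(salted, axis=-1)                 # [..., num_hashes]
    hashes = jnp.where(ids[..., None] == 0, 0, hashes)  # keep padding at 0
    embs = table[hashes]                                # [..., num_hashes, D]
    return embs.reshape(*ids.shape, num_hashes * dim)   # concat per id
```

The concatenated output is what the `block_*_reduce` functions above then project down to the model dimension D.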

Model Configuration

@dataclass
class PhoenixModelConfig:
    """Configuration for the recommendation system model."""

    model: TransformerConfig
    emb_size: int
    num_actions: int
    history_seq_len: int = 128
    candidate_seq_len: int = 32

    hash_config: HashConfig = None
    product_surface_vocab_size: int = 16
    fprop_dtype: Any = jnp.bfloat16
Source: recsys_model.py:245-261
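As a quick sanity check of the input layout implied by these defaults (the [B, 1 + S + C, D] sequence built in the forward pass above):

```python
# Defaults from PhoenixModelConfig.
history_seq_len = 128    # S: engagement history positions
candidate_seq_len = 32   # C: candidates scored per request

candidate_start_offset = 1 + history_seq_len          # 1 user token + history
total_seq_len = candidate_start_offset + candidate_seq_len

print(candidate_start_offset, total_seq_len)  # 129 161
```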

Key Differences from Retrieval

                 Retrieval               Ranking
  Model          Two separate towers     Single joint model
  Context        User tower only         User + history + candidates
  Output         Normalized embeddings   Action logits
  Objective      Similarity search       Engagement prediction

Next Steps

Architecture Details

Deep dive into attention masking implementation

Retrieval Model

Learn about the two-tower retrieval stage
