The Phoenix ranking model is a transformer-based neural network that scores candidates retrieved by the retrieval stage. It predicts multiple engagement actions (likes, reposts, replies, etc.) for each candidate while ensuring that scores are independent of other items in the batch.
Model Architecture
```
                       PHOENIX RANKING MODEL

                    OUTPUT LOGITS
          [B, num_candidates, num_actions]
                          │
                          │ Unembedding
                          │ Projection
                          │
          ┌───────────────┴───────────────┐
          │                               │
          │   Extract Candidate Outputs   │
          │   (positions after history)   │
          │                               │
          └───────────────┬───────────────┘
                          │
          ┌───────────────┴───────────────┐
          │                               │
          │          Transformer          │
          │    (with special masking)     │
          │                               │
          │   Candidates CANNOT attend    │
          │         to each other         │
          │                               │
          └───────────────┬───────────────┘
                          │
      ┌───────────────────┼───────────────────┐
      │                   │                   │
      ▼                   ▼                   ▼
┌───────────┐    ┌─────────────────┐ ┌─────────────────┐
│   User    │    │     History     │ │   Candidates    │
│ Embedding │    │   Embeddings    │ │   Embeddings    │
│ [B, 1, D] │    │    [B, S, D]    │ │    [B, C, D]    │
│           │    │                 │ │                 │
│   User    │    │ Posts + Authors │ │ Posts + Authors │
│  Hashes   │    │  + Actions +    │ │   + Product     │
│           │    │ Product Surface │ │     Surface     │
└───────────┘    └─────────────────┘ └─────────────────┘
```
Key Design: Candidate Isolation
Critical Design Requirement: The score for a candidate must not depend on which other candidates are in the batch. This ensures consistent scoring regardless of batch composition.
Candidate isolation is achieved through custom attention masking that prevents candidates from attending to each other while still allowing them to attend to user and history context.
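The exact mask-building code is not shown in this document, but the idea can be sketched as follows. This illustrative NumPy version (the real model uses JAX inside the transformer) builds a boolean attention mask in which every position may attend to the user and history context, while each candidate key is visible only to itself. Note that context positions also cannot attend to candidates; otherwise the context representations, and thus the candidate scores, would depend on batch composition.

```python
import numpy as np

def candidate_isolation_mask(seq_len: int, candidate_start: int) -> np.ndarray:
    """Build a [seq_len, seq_len] boolean attention mask.

    True = query position (row) may attend to key position (col).
    Positions before `candidate_start` (user + history) are visible
    to every query; a candidate key is visible only to itself.
    """
    query = np.arange(seq_len)[:, None]
    key = np.arange(seq_len)[None, :]
    is_candidate_key = key >= candidate_start
    # Context keys are always visible; candidate keys only on the diagonal.
    return ~is_candidate_key | (query == key)

# Tiny example: 3 context positions (user + history), 2 candidates.
mask = candidate_isolation_mask(seq_len=5, candidate_start=3)
# Candidate rows (3, 4) see columns 0-2 and themselves, never each other.
```

In the actual model this mask would be combined with the padding mask before the attention softmax; the function name and shapes here are illustrative, not taken from recsys_model.py.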
Why Candidate Isolation Matters
Without isolation, the model could:
Score candidates differently based on which other items are in the batch
Create inconsistent rankings across different batch compositions
Make A/B testing and experimentation unreliable
Cause unexpected behavior when batch sizes or sampling strategies change
With isolation, each candidate receives a deterministic score based solely on the user and history context.
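This batch-independence property can be demonstrated directly. The toy single-head attention below (plain NumPy, no learned weights, purely for illustration) applies the isolation pattern and verifies that a candidate's output is unchanged when a sibling candidate in the batch is swapped out:

```python
import numpy as np

def masked_attention(x: np.ndarray, candidate_start: int) -> np.ndarray:
    """Toy single-head self-attention with candidate isolation:
    candidate rows attend only to context positions and themselves."""
    T, D = x.shape
    q = np.arange(T)[:, None]
    k = np.arange(T)[None, :]
    allowed = (k < candidate_start) | (q == k)
    scores = (x @ x.T) / np.sqrt(D)
    scores = np.where(allowed, scores, -np.inf)  # mask disallowed keys
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x

rng = np.random.default_rng(0)
context = rng.normal(size=(4, 8))   # user + history positions
cand_a = rng.normal(size=(1, 8))    # the candidate we score
cand_b = rng.normal(size=(1, 8))    # one sibling candidate
cand_c = rng.normal(size=(1, 8))    # a different sibling candidate

out_1 = masked_attention(np.vstack([context, cand_a, cand_b]), 4)
out_2 = masked_attention(np.vstack([context, cand_a, cand_c]), 4)

# cand_a's representation is identical regardless of which sibling is present
assert np.allclose(out_1[4], out_2[4])
```

With a full (unmasked) attention matrix, the same check would fail: cand_a's output would shift depending on its batch-mates.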
Forward Pass Implementation
Here’s how the ranking model processes inputs:
```python
def __call__(
    self,
    batch: RecsysBatch,
    recsys_embeddings: RecsysEmbeddings,
) -> RecsysModelOutput:
    """Forward pass for ranking candidates.

    Args:
        batch: RecsysBatch containing hashes, actions, product surfaces.
        recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings.

    Returns:
        RecsysModelOutput containing logits for each candidate.
        Shape = [B, num_candidates, num_actions]
    """
    # Build input embeddings: [B, 1 + S + C, D]
    embeddings, padding_mask, candidate_start_offset = self.build_inputs(
        batch, recsys_embeddings
    )

    # Pass through transformer with candidate isolation
    model_output = self.model(
        embeddings,
        padding_mask,
        candidate_start_offset=candidate_start_offset,  # Enables special masking
    )
    out_embeddings = model_output.embeddings
    out_embeddings = layer_norm(out_embeddings)

    # Extract only candidate outputs (discard user + history)
    candidate_embeddings = out_embeddings[:, candidate_start_offset:, :]

    # Project to action logits
    unembeddings = self._get_unembedding()  # [D, num_actions]
    logits = jnp.dot(candidate_embeddings, unembeddings)  # [B, C, num_actions]

    return RecsysModelOutput(logits=logits)
```
Source: recsys_model.py:439-474
The model combines multiple feature types into input embeddings:
1. User Embedding
```python
def block_user_reduce(
    user_hashes: jnp.ndarray,      # [B, num_user_hashes]
    user_embeddings: jnp.ndarray,  # [B, num_user_hashes, D]
    num_user_hashes: int,
    emb_size: int,
) -> Tuple[jax.Array, jax.Array]:
    """Combine multiple user hash embeddings into a single user representation.

    Returns:
        user_embedding: [B, 1, D] - combined user embedding
        user_padding_mask: [B, 1] - True where user is valid
    """
    B, D = user_embeddings.shape[0], emb_size

    # Flatten multiple hash embeddings
    user_embedding = user_embeddings.reshape((B, 1, num_user_hashes * D))

    # Project to embedding dimension
    user_embedding = jnp.dot(user_embedding, proj_mat_1)  # [B, 1, D]

    # Create padding mask (hash 0 is reserved for padding)
    user_padding_mask = (user_hashes[:, 0] != 0).reshape(B, 1)

    return user_embedding, user_padding_mask
```
Source: recsys_model.py:79-119
2. History Embeddings
History embeddings combine:
Post hashes (multiple per position)
Author hashes (multiple per position)
Action embeddings (multi-hot vector of engagement types)
Product surface (e.g., Home, Search, Profile)
```python
def block_history_reduce(
    history_post_hashes: jnp.ndarray,                 # [B, S, num_item_hashes]
    history_post_embeddings: jnp.ndarray,             # [B, S, num_item_hashes, D]
    history_author_embeddings: jnp.ndarray,           # [B, S, num_author_hashes, D]
    history_product_surface_embeddings: jnp.ndarray,  # [B, S, D]
    history_actions_embeddings: jnp.ndarray,          # [B, S, D]
    ...
) -> Tuple[jax.Array, jax.Array]:
    """Combine history embeddings into sequence.

    Returns:
        history_embeddings: [B, S, D]
        history_padding_mask: [B, S]
    """
    # Concatenate all features
    post_author_embedding = jnp.concatenate(
        [
            history_post_embeddings_reshaped,
            history_author_embeddings_reshaped,
            history_actions_embeddings,
            history_product_surface_embeddings,
        ],
        axis=-1,
    )

    # Project to embedding dimension
    history_embedding = jnp.dot(post_author_embedding, proj_mat_3)

    return history_embedding, history_padding_mask
```
Source: recsys_model.py:122-182
3. Candidate Embeddings
Similar to history, but without action embeddings (since we’re predicting actions):
```python
def block_candidate_reduce(
    candidate_post_hashes: jnp.ndarray,                 # [B, C, num_item_hashes]
    candidate_post_embeddings: jnp.ndarray,             # [B, C, num_item_hashes, D]
    candidate_author_embeddings: jnp.ndarray,           # [B, C, num_author_hashes, D]
    candidate_product_surface_embeddings: jnp.ndarray,  # [B, C, D]
    ...
) -> Tuple[jax.Array, jax.Array]:
    """Combine candidate embeddings into sequence.

    Returns:
        candidate_embeddings: [B, C, D]
        candidate_padding_mask: [B, C]
    """
    post_author_embedding = jnp.concatenate(
        [
            candidate_post_embeddings_reshaped,
            candidate_author_embeddings_reshaped,
            candidate_product_surface_embeddings,
        ],
        axis=-1,
    )
    candidate_embedding = jnp.dot(post_author_embedding, proj_mat_2)
    return candidate_embedding, candidate_padding_mask
```
Source: recsys_model.py:185-242
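Once the three blocks above have run, build_inputs (whose body is not shown in this document) presumably concatenates their outputs along the sequence axis to form the transformer input. A minimal NumPy sketch of that layout, with toy dimensions:

```python
import numpy as np

# Toy dimensions: B=2 users, S=3 history items, C=2 candidates, D=4
B, S, C, D = 2, 3, 2, 4
rng = np.random.default_rng(0)
user_embedding = rng.normal(size=(B, 1, D))        # from block_user_reduce
history_embeddings = rng.normal(size=(B, S, D))    # from block_history_reduce
candidate_embeddings = rng.normal(size=(B, C, D))  # from block_candidate_reduce

# Concatenate along the sequence axis: [B, 1 + S + C, D]
embeddings = np.concatenate(
    [user_embedding, history_embeddings, candidate_embeddings], axis=1
)
candidate_start_offset = 1 + S  # candidates begin right after user + history

assert embeddings.shape == (B, 1 + S + C, D)
assert (embeddings[:, candidate_start_offset:, :] == candidate_embeddings).all()
```

This is why the forward pass can recover the candidate outputs simply by slicing at candidate_start_offset.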
Multi-Action Prediction
The model predicts multiple engagement types simultaneously:
```python
def _get_unembedding(self) -> jax.Array:
    """Get the unembedding matrix for decoding to logits."""
    unembed_mat = hk.get_parameter(
        "unembeddings",
        [config.emb_size, config.num_actions],
        dtype=jnp.float32,
        init=embed_init,
    )
    return unembed_mat

# Usage:
logits = jnp.dot(candidate_embeddings, unembeddings)
# Shape: [B, num_candidates, num_actions]
```
Each action dimension corresponds to a different engagement type:
Like probability
Repost probability
Reply probability
Click probability
Negative feedback probability
etc.
The final ranking score is typically a weighted combination of these action logits, where weights reflect business objectives (e.g., replies might be weighted higher than likes).
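A sketch of that final combination step, in NumPy. The action ordering and weight values here are hypothetical (the document does not specify them); the pattern is simply sigmoid over per-action logits, then a weighted sum into one scalar score per candidate:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Hypothetical action order and business weights (not from the source)
action_weights = np.array([
    0.5,   # like
    1.0,   # repost
    2.0,   # reply (weighted higher than likes)
    0.3,   # click
    -5.0,  # negative feedback (strongly penalized)
])

# Logits for two candidates, one row each: [C, num_actions]
logits = np.array([[1.2, -0.4, 0.1, 0.8, -2.0],
                   [0.3,  0.9, 1.5, -0.2, -3.0]])
probs = sigmoid(logits)          # per-action engagement probabilities
scores = probs @ action_weights  # one scalar ranking score per candidate
ranking = np.argsort(-scores)    # candidate indices, best first
```

Negative weights let the same mechanism demote candidates likely to draw negative feedback.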
Hash-Based Feature Encoding
Phoenix uses multiple hash functions for embedding lookup:
Why Multiple Hash Functions?
Multiple hash functions reduce hash collisions and provide richer representations:
Each feature (user, post, author) is hashed multiple times
Each hash produces a separate embedding lookup
Embeddings are concatenated and projected to the final dimension
This approach works well for sparse, high-cardinality features common in recommendation systems.
```python
@dataclass
class HashConfig:
    num_user_hashes: int = 2
    num_item_hashes: int = 2
    num_author_hashes: int = 2
```
Source: recsys_model.py:32-38
Hash value 0 is reserved for padding. The padding mask is computed by checking whether the first hash is non-zero: `padding_mask = (hashes[:, :, 0] != 0)`
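The lookup pattern described above can be sketched as follows. The hash functions and table size here are illustrative (Phoenix's actual hashing scheme is not shown in this document); the point is that each ID yields multiple table rows whose embeddings are concatenated before projection, and that index 0 is kept free for padding:

```python
import numpy as np

# Toy setup: each user ID is hashed twice into a shared embedding table.
TABLE_SIZE, D, NUM_HASHES = 997, 4, 2
rng = np.random.default_rng(0)
table = rng.normal(size=(TABLE_SIZE, D))

def multi_hash(user_id: int) -> np.ndarray:
    """Two independent hash functions (illustrative, not Phoenix's actual
    hashing); each maps the ID into [1, TABLE_SIZE) so 0 stays padding."""
    h1 = (user_id * 2654435761) % (TABLE_SIZE - 1) + 1
    h2 = (user_id * 40503 + 12345) % (TABLE_SIZE - 1) + 1
    return np.array([h1, h2])

hashes = multi_hash(user_id=42)               # [num_hashes]
embeddings = table[hashes]                    # [num_hashes, D]
flat = embeddings.reshape(1, NUM_HASHES * D)  # ready to project back to D
```

Even if one hash collides with another ID's, the chance that both collide is far smaller, which is why the multi-hash representation stays distinctive for high-cardinality features.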
Model Configuration
```python
@dataclass
class PhoenixModelConfig:
    """Configuration for the recommendation system model."""
    model: TransformerConfig
    emb_size: int
    num_actions: int
    history_seq_len: int = 128
    candidate_seq_len: int = 32
    hash_config: HashConfig = None
    product_surface_vocab_size: int = 16
    fprop_dtype: Any = jnp.bfloat16
```
Source: recsys_model.py:245-261
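With these defaults, the transformer's input sequence layout follows directly: one user token, then the history, then the candidates. A quick check of the implied offsets:

```python
# Sequence layout implied by the default config values above
history_seq_len = 128
candidate_seq_len = 32

candidate_start_offset = 1 + history_seq_len  # user token + history
total_seq_len = candidate_start_offset + candidate_seq_len

assert (candidate_start_offset, total_seq_len) == (129, 161)
```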
Key Differences from Retrieval

| Retrieval | Ranking |
| --- | --- |
| Two separate towers | Single joint model |
| User tower only | User + history + candidates |
| Output: normalized embeddings | Output: action logits |
| Objective: similarity search | Objective: engagement prediction |
Next Steps
Architecture Details Deep dive into attention masking implementation
Retrieval Model Learn about the two-tower retrieval stage