Overview
Phoenix is the machine learning system that powers both retrieval (finding relevant candidates from millions of posts) and ranking (scoring and ordering candidates by predicted engagement).
The system uses transformer-based architectures adapted from the Grok-1 open source release by xAI, with custom input embeddings and attention masking designed specifically for recommendation systems.
The code is representative of the production model with the exception of specific scaling optimizations.
Two-Stage Architecture
Phoenix operates in two distinct stages:
Retrieval (Two-Tower Model): narrows millions of posts to thousands using approximate nearest neighbor search
User Tower: Encodes user + engagement history
Candidate Tower: Encodes all posts
Similarity: Dot product for top-K selection
Ranking (Transformer with Candidate Isolation): scores retrieved candidates using a full transformer
Input: User context + candidate posts
Attention: Candidates isolated from each other
Output: Probabilities for multiple engagement types
┌──────────────────────────────────────────────────────────────┐
│ PHOENIX RECOMMENDATION PIPELINE │
├──────────────────────────────────────────────────────────────┤
│ │
│ ┌────────┐ ┌─────────────────┐ ┌──────────────────┐ │
│ │ User │──▶│ STAGE 1: │──▶│ STAGE 2: │──▶ │
│ │Request │ │ RETRIEVAL │ │ RANKING │ │
│ └────────┘ │ (Two-Tower) │ │ (Transformer) │ │
│ │ Millions→1000s │ │ 1000s→Ranked │ │
│ └─────────────────┘ └──────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────┘
Retrieval: Two-Tower Model
The retrieval stage efficiently finds relevant candidates from a massive corpus.
Architecture
User Tower encodes user features and engagement history:
phoenix/recsys_retrieval_model.py
class UserTower(hk.Module):
    """User tower that processes engagement history into user representation."""

    def __call__(self, batch, embeddings):
        # Combine user and history embeddings
        user_embedding, user_mask = block_user_reduce(
            batch.user_hashes,
            embeddings.user_embeddings,
            num_user_hashes=self.hash_config.num_user_hashes,
            emb_size=self.transformer_config.emb_size,
        )
        history_embedding, history_mask = block_history_reduce(
            batch.history_post_hashes,
            embeddings.history_post_embeddings,
            # ... encode posts, authors, actions
        )
        # Pass through transformer
        transformer_input = jnp.concatenate([user_embedding, history_embedding], axis=1)
        output = self.transformer(transformer_input, padding_mask)
        # Extract and normalize user representation
        user_representation = output[:, 0, :]  # First position = user
        return normalize(user_representation)
Candidate Tower projects post + author embeddings:
class CandidateTower(hk.Module):
    """Candidate tower that encodes posts into shared embedding space."""

    def __call__(self, post_author_embedding):
        # Two-layer MLP with SiLU activation
        hidden = jnp.dot(post_author_embedding, proj_1)
        hidden = jax.nn.silu(hidden)
        candidate_embeddings = jnp.dot(hidden, proj_2)
        # L2 normalization
        return normalize(candidate_embeddings)
Similarity Search
Once both towers produce normalized embeddings:
Index building: all posts encoded offline into an [N, D] matrix
Query encoding: user tower produces a [B, D] embedding at request time
Top-K retrieval: dot-product similarity → select top candidates
# Similarity scores via dot product (since normalized, this is cosine similarity)
scores = user_representation @ candidate_embeddings.T  # [B, N]
top_k_indices = jnp.argsort(scores, axis=-1)[..., -K:]  # Top K
Because both representations are L2-normalized, dot product equals cosine similarity, enabling efficient approximate nearest neighbor search with libraries like FAISS or ScaNN.
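As a minimal illustration of why the normalization matters, here is a brute-force version of the same lookup in plain Python (a sketch, not the production ANN path, which would use a library like FAISS or ScaNN):

```python
import math

def l2_normalize(v):
    # Scale a vector to unit length so dot product equals cosine similarity.
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def top_k(user_vec, candidate_vecs, k):
    # Brute-force dot-product scoring over all candidates.
    scores = [sum(u * c for u, c in zip(user_vec, vec)) for vec in candidate_vecs]
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]

user = l2_normalize([1.0, 2.0, 0.5])
candidates = [l2_normalize(v) for v in ([1.0, 2.0, 0.4], [-1.0, 0.0, 3.0], [2.0, 4.0, 1.0])]
print(top_k(user, candidates, 2))  # → [2, 0]
```

Candidate 2 is a scaled copy of the user vector, so after normalization its dot product is exactly the cosine similarity of 1.0 and it ranks first; the ANN index performs the same comparison approximately, without touching every candidate.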
Ranking: Transformer with Candidate Isolation
The ranking model scores the retrieved candidates using a full transformer architecture with a critical design choice: candidates cannot attend to each other.
Model Architecture
class RecsysModel(hk.Module):
    """Recommendation model for ranking candidates."""

    def __call__(self, batch, embeddings):
        # 1. Reduce hash embeddings
        user_embedding, user_mask = block_user_reduce(...)
        history_embedding, history_mask = block_history_reduce(...)
        candidate_embedding, candidate_mask = block_candidate_reduce(...)

        # 2. Concatenate sequence: [user, history, candidates]
        transformer_input = jnp.concatenate([
            user_embedding,       # [B, 1, D]
            history_embedding,    # [B, S, D]
            candidate_embedding,  # [B, C, D]
        ], axis=1)

        # 3. Create candidate isolation mask
        padding_mask = create_candidate_isolation_mask(
            user_mask, history_mask, candidate_mask
        )

        # 4. Transform with special attention masking
        outputs = self.transformer(
            transformer_input,
            padding_mask=padding_mask,
        )

        # 5. Extract candidate outputs and predict actions
        candidate_outputs = outputs[:, history_len:, :]  # [B, C, D]
        logits = self.unembedding_layer(candidate_outputs)  # [B, C, num_actions]
        return RecsysModelOutput(logits=logits)
Candidate Isolation Mask
The attention mask ensures candidates only attend to user/history, never to each other:
Keys (what we attend TO)
────────────────────────────────────────────────▶
│ User │ History (S) │ Candidates (C) │
┌───┼──────┼───────────────┼────────────────────┤
│ │ │ │ │
│ U │ ✓ │ ✓ ✓ ✓ │ ✗ ✗ ✗ ✗ │
│ │ │ │ │
├───┼──────┼───────────────┼────────────────────┤
Q│ │ │ │ │
u│ H │ ✓ │ ✓ ✓ ✓ │ ✗ ✗ ✗ ✗ │
e│ i │ ✓ │ ✓ ✓ ✓ │ ✗ ✗ ✗ ✗ │
r│ s │ ✓ │ ✓ ✓ ✓ │ ✗ ✗ ✗ ✗ │
i│ │ │ │ │
e├───┼──────┼───────────────┼────────────────────┤
s│ │ │ │ Diagonal only! │
││ C │ ✓ │ ✓ ✓ ✓ │ ✓ ✗ ✗ ✗ │
││ a │ ✓ │ ✓ ✓ ✓ │ ✗ ✓ ✗ ✗ │
▼│ n │ ✓ │ ✓ ✓ ✓ │ ✗ ✗ ✓ ✗ │
│ d │ ✓ │ ✓ ✓ ✓ │ ✗ ✗ ✗ ✓ │
│ │ │ │ │
└───┴──────┴───────────────┴────────────────────┘
✓ = Can attend (1) ✗ = Cannot attend (0)
Why Candidate Isolation? Candidates are prevented from attending to each other to ensure score independence: the score for a post doesn't depend on which other posts are in the batch. This makes scores consistent and cacheable across different batches.
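The mask in the diagram can be materialized as a 0/1 matrix. The sketch below is plain Python with an illustrative helper name; the production code builds the mask from the per-block padding masks instead:

```python
def candidate_isolation_mask(num_history, num_candidates):
    # Sequence layout: [user (1), history (S), candidates (C)].
    # Rows are queries, columns are keys; 1 = may attend, 0 = masked out.
    seq_len = 1 + num_history + num_candidates
    cand_start = 1 + num_history
    mask = [[0] * seq_len for _ in range(seq_len)]
    for q in range(seq_len):
        for k in range(seq_len):
            if k < cand_start:
                mask[q][k] = 1  # every position may attend to user + history
            elif q == k:
                mask[q][k] = 1  # a candidate may attend to itself (the diagonal)
            # otherwise: candidate-to-other-candidate attention stays masked
    return mask

m = candidate_isolation_mask(num_history=3, num_candidates=4)
```

With S=3 and C=4 this reproduces the 8×8 pattern above: all-ones over the user/history columns, and identity over the candidate block.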
Multi-Action Prediction
The model predicts probabilities for multiple engagement types simultaneously:
# Output shape: [B, num_candidates, num_actions]
logits = model(batch, embeddings)
# Actions predicted:
P(favorite)
P(reply)
P(repost)
P(quote)
P(click)
P(profile_click)
P(video_view)
P(photo_expand)
P(share)
P(dwell)
P(follow_author)
P(not_interested) # Negative signal
P(block_author) # Negative signal
P(mute_author) # Negative signal
P(report) # Negative signal
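Downstream ranking typically collapses these per-action probabilities into a single ordering score via a weighted sum, with negative signals subtracting. The weights below are purely illustrative, not the production values:

```python
# Illustrative action weights (assumed, not the production configuration).
WEIGHTS = {
    "favorite": 1.0,
    "reply": 2.0,
    "repost": 1.5,
    "report": -10.0,  # strong negative signal
}

def final_score(probs):
    # Weighted sum over whichever actions the model emitted;
    # actions without a configured weight contribute nothing.
    return sum(WEIGHTS.get(action, 0.0) * p for action, p in probs.items())

probs = {"favorite": 0.2, "reply": 0.05, "repost": 0.1, "report": 0.01}
score = final_score(probs)  # 1.0*0.2 + 2.0*0.05 + 1.5*0.1 - 10.0*0.01 = 0.35
```

Because each probability is predicted independently, the weighting can be retuned (or personalized) without retraining the model.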
Hash-Based Embeddings
Both retrieval and ranking use multiple hash functions for embedding lookup:
class HashConfig:
    num_user_hashes: int = 2    # User ID → 2 hash functions
    num_item_hashes: int = 2    # Post ID → 2 hash functions
    num_author_hashes: int = 2  # Author ID → 2 hash functions
Each entity is hashed multiple times, and the resulting embeddings are combined:
def block_user_reduce(user_hashes, user_embeddings, ...):
    # user_hashes: [B, num_user_hashes]
    # user_embeddings: [B, num_user_hashes, D]
    # Learn a projection to combine hash embeddings
    projection = hk.get_parameter(...)
    combined = apply_projection(user_embeddings, projection)
    return combined  # [B, 1, D]
Multiple hash functions provide better representation capacity and collision resistance compared to a single hash table.
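The idea can be sketched in plain Python: hash each entity ID with several salts into one shared table and combine the resulting rows. Here they are averaged for simplicity, whereas the production code learns a projection; the table contents are toy values:

```python
import hashlib

TABLE_SIZE = 1024
DIM = 4
# Toy embedding table: deterministic rows derived from the slot index.
table = [[((7 * slot + d) % 13) / 13.0 for d in range(DIM)] for slot in range(TABLE_SIZE)]

def hash_slots(entity_id, num_hashes):
    # Salted hashes map one ID to several table slots.
    slots = []
    for salt in range(num_hashes):
        digest = hashlib.sha256(f"{salt}:{entity_id}".encode()).hexdigest()
        slots.append(int(digest, 16) % TABLE_SIZE)
    return slots

def embed(entity_id, num_hashes=2):
    # Combine the rows; two IDs colliding on one hash rarely collide on all of them,
    # so the combined vector stays distinctive even with a small table.
    rows = [table[s] for s in hash_slots(entity_id, num_hashes)]
    return [sum(col) / num_hashes for col in zip(*rows)]

vec = embed("user_12345")
```

This is why the scheme is memory-efficient: the table size is fixed regardless of how many entity IDs exist, and extra hash functions trade a little compute for collision resistance.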
Integration with Home Mixer
Phoenix Source (Retrieval)
home-mixer/sources/phoenix_source.rs
pub struct PhoenixSource {
    pub phoenix_retrieval_client: Arc<dyn PhoenixRetrievalClient>,
}

impl Source<ScoredPostsQuery, PostCandidate> for PhoenixSource {
    async fn get_candidates(&self, query: &ScoredPostsQuery) -> Result<Vec<PostCandidate>> {
        let sequence = query.user_action_sequence.as_ref()?;
        let response = self.phoenix_retrieval_client
            .retrieve(query.user_id, sequence.clone(), MAX_RESULTS)
            .await?;
        let candidates = response.top_k_candidates
            .into_iter()
            .map(|c| PostCandidate {
                tweet_id: c.tweet_id,
                author_id: c.author_id,
                // ...
            })
            .collect();
        Ok(candidates)
    }
}
Phoenix Scorer (Ranking)
home-mixer/scorers/phoenix_scorer.rs
pub struct PhoenixScorer {
    pub phoenix_client: Arc<dyn PhoenixPredictionClient>,
}

impl Scorer<ScoredPostsQuery, PostCandidate> for PhoenixScorer {
    async fn score(&self, query: &ScoredPostsQuery, candidates: &[PostCandidate])
        -> Result<Vec<PostCandidate>> {
        let tweet_infos = candidates.iter().map(|c| TweetInfo {
            tweet_id: c.tweet_id,
            author_id: c.author_id,
            // ...
        }).collect();
        let response = self.phoenix_client
            .predict(query.user_id, sequence, tweet_infos)
            .await?;
        // Extract predictions and update candidates
        let scored_candidates = candidates.iter().map(|c| {
            let phoenix_scores = extract_scores(&response, c.tweet_id);
            PostCandidate {
                phoenix_scores,
                ..c.clone()
            }
        }).collect();
        Ok(scored_candidates)
    }
}
Running the Code
The repository includes example code demonstrating both retrieval and ranking.
Key Design Decisions
Why Candidate Isolation?
Prevents the score for a candidate from depending on which other candidates are in the batch. This ensures:
Consistent scores across different batches
Ability to cache predictions
Simpler debugging and analysis
Why Hash-Based Embeddings?
Multiple hash functions provide:
Better representation capacity than single lookup
Collision resistance for large entity spaces
Memory efficiency compared to explicit embedding tables
Why Multi-Action Prediction?
Rather than predicting a single “relevance” score, the model predicts probabilities for many actions:
Captures nuanced user preferences
Enables flexible weighting strategies
Incorporates negative signals (block, mute, report)
Why Two-Stage (Retrieval + Ranking)?
Retrieval : Fast, approximate search over millions of items
Ranking : Expensive, precise scoring for hundreds of items
This separation enables scaling to large corpora while maintaining quality.
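The structure can be sketched in a few lines: a cheap score over the full corpus selects survivors, and an expensive score runs only over those. Both scoring functions below are arbitrary stand-ins, not the real models:

```python
def cheap_score(item):
    # Stand-in for the two-tower dot product (fast, approximate).
    return item % 7

def expensive_score(item):
    # Stand-in for the full-transformer forward pass (slow, precise).
    return (item * 31) % 97

def two_stage(corpus, k_retrieve, k_final):
    # Stage 1: cheap scoring over everything, keep the top k_retrieve.
    retrieved = sorted(corpus, key=cheap_score, reverse=True)[:k_retrieve]
    # Stage 2: expensive scoring only over the retrieved subset.
    return sorted(retrieved, key=expensive_score, reverse=True)[:k_final]

corpus = list(range(1000))
result = two_stage(corpus, k_retrieve=50, k_final=5)
```

The expensive function touches 50 items instead of 1000; at production scale the same shape turns "millions × transformer" into "millions × dot product + hundreds × transformer".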
Typical Latencies
Retrieval (Two-Tower): ~20-50ms for top-1000 from millions
Ranking (Transformer): ~50-100ms for scoring 500 candidates
Total Phoenix latency: ~70-150ms
Home Mixer: Orchestration layer that uses Phoenix for candidate sourcing and scoring
Thunder: Provides in-network candidates to complement Phoenix's out-of-network retrieval
Candidate Pipeline: Framework that integrates Phoenix into the overall recommendation flow