
Overview

The PhoenixRetrievalModel implements a two-tower architecture for efficient retrieval of relevant posts. It uses the Phoenix transformer to encode user representations and a separate candidate tower for post embeddings, enabling fast approximate nearest neighbor (ANN) search.

Classes

PhoenixRetrievalModel

A two-tower retrieval model using the Phoenix transformer for user encoding.
model (Transformer, required): The underlying Grok transformer model for user encoding
config (PhoenixRetrievalModelConfig, required): Model configuration including embedding size and sequence lengths
fprop_dtype (Any, default: jnp.bfloat16): Forward propagation data type for computation
name (Optional[str]): Optional name for the model

PhoenixRetrievalModelConfig

Configuration for the Phoenix Retrieval Model. Uses the same transformer architecture as the Phoenix ranker for encoding user representations.
model (TransformerConfig, required): Transformer architecture configuration
emb_size (int, required): Embedding dimension size
history_seq_len (int, default: 128): Maximum length of the user history sequence
candidate_seq_len (int, default: 32): Maximum number of candidate posts (used for batch processing)
name (Optional[str]): Optional model name
fprop_dtype (Any, default: jnp.bfloat16): Forward propagation data type
hash_config (HashConfig): Hash configuration for multi-hash embeddings
product_surface_vocab_size (int, default: 16): Size of the product surface vocabulary

CandidateTower

Candidate tower that projects post+author embeddings to a shared embedding space.
emb_size (int, required): Output embedding dimension
name (Optional[str]): Optional name for the tower

Named Tuples

RetrievalOutput

Output of the retrieval model.
user_representation (jax.Array): L2-normalized user embedding [B, D]
top_k_indices (jax.Array): Indices of top-k retrieved candidates [B, K]
top_k_scores (jax.Array): Similarity scores for top-k candidates [B, K]

Methods

__call__

def __call__(
    batch: RecsysBatch,
    recsys_embeddings: RecsysEmbeddings,
    corpus_embeddings: jax.Array,
    top_k: int,
    corpus_mask: Optional[jax.Array] = None,
) -> RetrievalOutput
Retrieve top-k candidates from corpus for each user.
batch (RecsysBatch, required): Batch containing hashes, actions, and product surfaces
recsys_embeddings (RecsysEmbeddings, required): Pre-looked-up embeddings from embedding tables
corpus_embeddings (jax.Array, required): Normalized corpus candidate embeddings [N, D]
top_k (int, required): Number of candidates to retrieve
corpus_mask (Optional[jax.Array]): Optional mask for valid corpus entries [N]
Returns (RetrievalOutput): Retrieval output containing the user representation and top-k results

build_user_representation

def build_user_representation(
    batch: RecsysBatch,
    recsys_embeddings: RecsysEmbeddings,
) -> Tuple[jax.Array, jax.Array]
Build user representation from user features and history using the Phoenix transformer.
batch (RecsysBatch, required): Batch containing hashes, actions, and product surfaces
recsys_embeddings (RecsysEmbeddings, required): Pre-looked-up embeddings from embedding tables
Returns:
  user_representation (jax.Array): L2-normalized user embedding [B, D]
  user_norm (jax.Array): Pre-normalization L2 norm [B, 1]
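The normalize-and-keep-the-norm return pair can be sketched with a small helper. This is illustrative only (the helper name and shapes are assumptions, not the module's actual code); the EPS constant follows the Usage Notes at the end of this page:

```python
import jax
import jax.numpy as jnp

EPS = 1e-12  # numerical-stability constant noted under Usage Notes

def l2_normalize_with_norm(x: jax.Array) -> tuple[jax.Array, jax.Array]:
    # Normalize along the last axis and keep the pre-normalization norm,
    # mirroring the (user_representation, user_norm) return pair.
    norm = jnp.linalg.norm(x, axis=-1, keepdims=True)  # [B, 1]
    return x / (norm + EPS), norm

user_raw = jnp.ones((4, 8))  # stand-in for the transformer's pooled output [B, D]
user_representation, user_norm = l2_normalize_with_norm(user_raw)
```

Keeping the pre-normalization norm alongside the unit vector is useful for monitoring embedding magnitude during training without recomputing it.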

build_candidate_representation

def build_candidate_representation(
    batch: RecsysBatch,
    recsys_embeddings: RecsysEmbeddings,
) -> Tuple[jax.Array, jax.Array]
Build candidate (item) representations by projecting post + author embeddings to a shared space.
batch (RecsysBatch, required): Batch containing candidate hashes
recsys_embeddings (RecsysEmbeddings, required): Pre-looked-up embeddings containing candidate embeddings
Returns:
  candidate_representation (jax.Array): L2-normalized candidate embeddings [B, C, D]
  candidate_padding_mask (jax.Array): Valid candidate mask [B, C]

_retrieve_top_k

def _retrieve_top_k(
    user_representation: jax.Array,
    corpus_embeddings: jax.Array,
    top_k: int,
    corpus_mask: Optional[jax.Array] = None,
) -> Tuple[jax.Array, jax.Array]
Retrieve top-k candidates from a corpus for each user using dot product similarity.
user_representation (jax.Array, required): Normalized user embeddings [B, D]
corpus_embeddings (jax.Array, required): Normalized corpus candidate embeddings [N, D]
top_k (int, required): Number of candidates to retrieve
corpus_mask (Optional[jax.Array]): Optional mask for valid corpus entries [N]
Returns:
  top_k_indices (jax.Array): Indices of top-k candidates [B, K]
  top_k_scores (jax.Array): Similarity scores of top-k candidates [B, K]
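The score-mask-select logic described above can be sketched as follows. This is an illustrative reimplementation under the stated conventions (dot-product scores, the INF constant from the Usage Notes, `jax.lax.top_k`), not the module's actual code:

```python
from typing import Optional

import jax
import jax.numpy as jnp

INF = 1e12  # masking constant noted under Usage Notes

def retrieve_top_k(
    user_representation: jax.Array,           # [B, D], normalized
    corpus_embeddings: jax.Array,             # [N, D], normalized
    top_k: int,
    corpus_mask: Optional[jax.Array] = None,  # [N], True = valid entry
) -> tuple[jax.Array, jax.Array]:
    scores = user_representation @ corpus_embeddings.T  # [B, N] dot products
    if corpus_mask is not None:
        # Push invalid corpus entries below any real score so they
        # can never appear in the top-k.
        scores = jnp.where(corpus_mask[None, :], scores, -INF)
    top_k_scores, top_k_indices = jax.lax.top_k(scores, top_k)  # [B, K] each
    return top_k_indices, top_k_scores

# Toy corpus: rows of the identity are already unit-norm; entry 2 is masked out,
# so the user's best match (index 2) must be skipped.
corpus = jnp.eye(3)
users = jnp.array([[0.0, 0.0, 1.0]])
mask = jnp.array([True, True, False])
indices, scores = retrieve_top_k(users, corpus, top_k=2, corpus_mask=mask)
```

In production the exhaustive matmul over the corpus would be replaced by an ANN index; this exact variant is what makes the towers' outputs directly comparable offline.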

CandidateTower Methods

__call__

def __call__(
    post_author_embedding: jax.Array,
) -> jax.Array
Project post+author embeddings to a normalized representation in the shared embedding space.
post_author_embedding (jax.Array, required): Concatenated post and author embeddings. Shape: [B, C, num_hashes, D] or [B, num_hashes, D]
Returns (jax.Array): Normalized candidate representation. Shape: [B, C, D] or [B, D]
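A minimal sketch of this projection, following the Architecture notes below (two-layer MLP with SiLU, then L2 normalization). Flattening the num_hashes axis into the feature axis is an assumption about how the concatenated hashes are consumed, and the weights `w1`/`b1`/`w2`/`b2` are hypothetical stand-ins for the module's learned parameters:

```python
import jax
import jax.numpy as jnp

EPS = 1e-12  # numerical-stability constant noted under Usage Notes

def candidate_tower(post_author_embedding, w1, b1, w2, b2):
    # Merge the num_hashes axis into the feature axis (an assumption about
    # how the concatenated hashes are consumed), then project with a
    # two-layer SiLU MLP and L2-normalize the output.
    *lead, num_hashes, d = post_author_embedding.shape
    x = post_author_embedding.reshape(*lead, num_hashes * d)
    h = jax.nn.silu(x @ w1 + b1)   # hidden layer
    out = h @ w2 + b2              # [..., emb_size]
    return out / (jnp.linalg.norm(out, axis=-1, keepdims=True) + EPS)

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
w1 = jax.random.normal(k1, (6, 16)) * 0.1   # num_hashes * D = 2 * 3 = 6
b1 = jax.random.normal(k2, (16,)) * 0.1
w2 = jax.random.normal(k3, (16, 4)) * 0.1   # emb_size = 4
b2 = jax.random.normal(k4, (4,)) * 0.1
emb = jnp.ones((2, 5, 2, 3))                # [B=2, C=5, num_hashes=2, D=3]
cand = candidate_tower(emb, w1, b1, w2, b2)
```

Because the same function applies over any number of leading axes, it handles both the batched-candidates shape [B, C, num_hashes, D] and the single-candidate shape [B, num_hashes, D].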

Architecture

The two-tower architecture consists of:
  1. User Tower: Encodes user features + history using the Phoenix transformer
    • Processes user embeddings and interaction history
    • Outputs a single L2-normalized user representation vector
  2. Candidate Tower: Projects candidate embeddings to a shared space
    • Two-layer MLP with SiLU activation
    • Projects concatenated post+author embeddings
    • Outputs L2-normalized candidate representations
  3. Retrieval: Uses dot product similarity between normalized embeddings
    • Enables efficient ANN search in production
    • Returns top-k candidates based on similarity scores

Usage Notes

  • Both user and candidate representations are L2-normalized
  • Dot product similarity is used for retrieval (equivalent to cosine similarity for normalized vectors)
  • The model is designed for efficient ANN search using libraries like FAISS or ScaNN
  • Constants: EPS = 1e-12, INF = 1e12 for numerical stability
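The cosine-equivalence note above is easy to verify numerically; a minimal check with illustrative values:

```python
import jax.numpy as jnp

EPS = 1e-12  # stability constant from the notes above

a = jnp.array([3.0, 4.0])
b = jnp.array([1.0, 0.0])

# Cosine similarity on the raw vectors: 3 / (5 * 1) = 0.6.
cos_ab = float(jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + EPS))

# Dot product after L2 normalization gives the same number.
a_n = a / (jnp.linalg.norm(a) + EPS)
b_n = b / (jnp.linalg.norm(b) + EPS)
dot_n = float(jnp.dot(a_n, b_n))
```

This is why the retrieval step can use a plain inner-product ANN index: normalizing both towers' outputs turns maximum inner product search into cosine search.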
