Learn how Phoenix uses a two-tower architecture to efficiently retrieve relevant candidates from millions of items
The retrieval stage is the first phase of the Phoenix recommendation pipeline. It efficiently narrows down millions of potential candidates to hundreds of relevant items using a two-tower architecture that enables fast similarity search.
The user tower leverages the same transformer architecture used in ranking:
```python
def build_user_representation(
    self,
    batch: RecsysBatch,
    recsys_embeddings: RecsysEmbeddings,
) -> Tuple[jax.Array, jax.Array]:
    """Build user representation from user features and history.

    Uses the Phoenix transformer to encode user + history embeddings
    into a single user representation vector.

    Returns:
        user_representation: L2-normalized user embedding [B, D]
        user_norm: Pre-normalization L2 norm [B, 1]
    """
    # Combine user and history embeddings
    embeddings = jnp.concatenate([user_embeddings, history_embeddings], axis=1)
    padding_mask = jnp.concatenate([user_padding_mask, history_padding_mask], axis=1)

    # Pass through transformer
    model_output = self.model(
        embeddings.astype(self.fprop_dtype),
        padding_mask,
        candidate_start_offset=None,
    )

    # Average pool over valid positions
    user_outputs = model_output.embeddings
    mask_float = padding_mask.astype(jnp.float32)[:, :, None]
    user_embeddings_masked = user_outputs * mask_float
    user_embedding_sum = jnp.sum(user_embeddings_masked, axis=1)
    mask_sum = jnp.sum(mask_float, axis=1)
    user_representation = user_embedding_sum / jnp.maximum(mask_sum, 1.0)

    # L2 normalize
    user_norm_sq = jnp.sum(user_representation**2, axis=-1, keepdims=True)
    user_norm = jnp.sqrt(jnp.maximum(user_norm_sq, EPS))
    user_representation = user_representation / user_norm

    return user_representation, user_norm
```
The user tower uses average pooling over the transformer outputs, weighted by the padding mask. This creates a single vector representation that captures the full user context.
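The pooling step can be illustrated in isolation. The NumPy sketch below mirrors the masked-mean-then-normalize logic of the user tower (`masked_mean_pool` is a hypothetical helper, not part of Phoenix); positions where the padding mask is zero contribute nothing to the average.

```python
import numpy as np

EPS = 1e-12

def masked_mean_pool(outputs, padding_mask):
    """Average transformer outputs over valid (unmasked) positions,
    then L2-normalize, mirroring the user tower's pooling."""
    mask = padding_mask.astype(np.float32)[:, :, None]      # [B, T, 1]
    summed = (outputs * mask).sum(axis=1)                   # [B, D]
    count = np.maximum(mask.sum(axis=1), 1.0)               # [B, 1]
    pooled = summed / count
    norm = np.sqrt(np.maximum((pooled ** 2).sum(-1, keepdims=True), EPS))
    return pooled / norm, norm

# Two toy sequences: the second has one padded position that is ignored.
outputs = np.array([[[1.0, 0.0], [3.0, 0.0]],
                    [[2.0, 2.0], [9.0, 9.0]]])
mask = np.array([[1, 1], [1, 0]])
vecs, norms = masked_mean_pool(outputs, mask)
# First row pools to the mean of [1, 0] and [3, 0], i.e. (2, 0),
# which normalizes to the unit vector (1, 0).
```

The `jnp.maximum(mask_sum, 1.0)` guard in the real code serves the same purpose as `np.maximum(count, 1.0)` here: it avoids division by zero for fully padded rows.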
Once both towers produce normalized embeddings, retrieval becomes a simple dot product:
```python
def _retrieve_top_k(
    self,
    user_representation: jax.Array,  # [B, D]
    corpus_embeddings: jax.Array,    # [N, D]
    top_k: int,
    corpus_mask: Optional[jax.Array] = None,
) -> Tuple[jax.Array, jax.Array]:
    """Retrieve top-k candidates from a corpus for each user.

    Returns:
        top_k_indices: [B, K] indices of top-k candidates
        top_k_scores: [B, K] similarity scores of top-k candidates
    """
    # Compute similarity scores
    scores = jnp.matmul(user_representation, corpus_embeddings.T)  # [B, N]

    # Apply corpus mask if provided
    if corpus_mask is not None:
        scores = jnp.where(corpus_mask[None, :], scores, -INF)

    # Select top-k
    top_k_scores, top_k_indices = jax.lax.top_k(scores, top_k)

    return top_k_indices, top_k_scores
```
Why L2 normalization?

Normalizing embeddings to unit length converts cosine similarity into a simple dot product. This enables the use of highly optimized approximate nearest neighbor (ANN) libraries like FAISS or ScaNN for efficient retrieval at scale.
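The equivalence is easy to verify numerically: once both vectors are normalized up front, their plain dot product equals the cosine similarity of the originals.

```python
import numpy as np

def cosine(a, b):
    """Cosine similarity of two raw (unnormalized) vectors."""
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

u = np.array([3.0, 4.0])
v = np.array([1.0, 2.0])

# Normalize once up front...
u_hat = u / np.linalg.norm(u)
v_hat = v / np.linalg.norm(v)

# ...and a plain dot product recovers the cosine similarity.
assert np.isclose(np.dot(u_hat, v_hat), cosine(u, v))
```

This is why retrieval can pre-normalize the corpus offline and score at serving time with nothing but a matrix multiply.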
Candidate embeddings are pre-computed offline and stored in a vector database. Only the user tower runs at inference time.
ANN Index
In production, exact top-k search is replaced with approximate nearest neighbor algorithms (e.g., FAISS, ScaNN) that provide sub-linear search complexity.
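To make the sub-linear idea concrete, here is a toy inverted-file (IVF-style) sketch in NumPy, assuming unit-normalized embeddings. It is not Phoenix's production index, but it shows the core trade: items are bucketed by their nearest coarse centroid, and a query scores only the few buckets whose centroids match it best, instead of the full corpus.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy corpus of unit vectors and a handful of coarse centroids
# (sampled from the corpus itself, as IVF training often does).
corpus = rng.normal(size=(1000, 16))
corpus /= np.linalg.norm(corpus, axis=1, keepdims=True)
centroids = corpus[rng.choice(1000, size=8, replace=False)]

# Assign every item to its nearest centroid's bucket.
assignment = np.argmax(corpus @ centroids.T, axis=1)
buckets = {c: np.where(assignment == c)[0] for c in range(8)}

def ann_search(query, top_k=5, nprobe=2):
    """Score only the nprobe buckets whose centroids best match the query."""
    probe = np.argsort(-(query @ centroids.T))[:nprobe]
    cand = np.concatenate([buckets[c] for c in probe])
    scores = corpus[cand] @ query
    order = np.argsort(-scores)[:top_k]
    return cand[order], scores[order]

query = corpus[0]            # use a known corpus item as the query
idx, scores = ann_search(query)
```

With `nprobe=2` out of 8 buckets, only about a quarter of the corpus is scored; production libraries like FAISS and ScaNN layer compression and SIMD-optimized scoring on top of this same partition-then-probe structure.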
Batch Processing
Both towers support batched computation for efficient training and offline candidate encoding.
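The offline side of that batching can be sketched as a chunked encoding loop. Everything here is illustrative: `encode_corpus` and `item_tower` are hypothetical stand-ins for the candidate tower, not Phoenix API names.

```python
import numpy as np

def encode_corpus(item_features, item_tower, batch_size=256):
    """Encode the full item corpus in fixed-size batches, as an
    offline job would, L2-normalizing so retrieval is a dot product."""
    chunks = []
    for start in range(0, len(item_features), batch_size):
        emb = item_tower(item_features[start:start + batch_size])
        emb /= np.linalg.norm(emb, axis=1, keepdims=True)
        chunks.append(emb)
    return np.concatenate(chunks, axis=0)

# Stand-in tower: a fixed random projection from 32 to 8 dims.
rng = np.random.default_rng(0)
proj = rng.normal(size=(32, 8))
corpus = encode_corpus(rng.normal(size=(1000, 32)), lambda x: x @ proj)
```

The resulting matrix is what gets written to the vector database; at serving time only the user tower runs, and its output is matched against these pre-normalized rows.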
```python
@dataclass
class PhoenixRetrievalModelConfig:
    """Configuration for the Phoenix Retrieval Model.

    This model uses the same transformer architecture as the Phoenix
    ranker for encoding user representations.
    """

    model: TransformerConfig
    emb_size: int
    history_seq_len: int = 128
    candidate_seq_len: int = 32
    hash_config: Optional[HashConfig] = None
    product_surface_vocab_size: int = 16
    fprop_dtype: Any = jnp.bfloat16
```