The retrieval stage is the first phase of the Phoenix recommendation pipeline. It efficiently narrows down millions of potential candidates to hundreds of relevant items using a two-tower architecture that enables fast similarity search.

Architecture Overview

[Two-tower architecture diagram]

The two-tower model consists of:
  1. User Tower: Encodes user features and engagement history
  2. Candidate Tower: Encodes candidate item features
  3. Similarity Search: Retrieves top-K candidates using dot product

How It Works

Step 1: User Representation

The user tower processes:
  • User identifiers (via hash embeddings)
  • Recent engagement history (posts, authors, actions)
  • Product surface context
These are passed through the Phoenix transformer to produce a normalized user embedding of shape [B, D].

Step 2: Candidate Representation

The candidate tower projects post and author embeddings through a 2-layer MLP:
  • Layer 1: Projects to 2*D dimensions with SiLU activation
  • Layer 2: Projects to D dimensions
  • L2 normalization produces the final candidate embedding

Step 3: Similarity Search

Compute dot product between user and all candidate embeddings:
scores = user_embedding @ candidate_embeddings.T  # [B, N]
top_k_indices = top_k(scores, k)  # [B, K]
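
As a sanity check, this scoring-and-selection step can be run end to end on toy embeddings. A minimal sketch (the array values are illustrative, not real embeddings):

```python
import jax
import jax.numpy as jnp

def l2_normalize(x, eps=1e-12):
    """Normalize rows of x to unit L2 norm."""
    return x / jnp.sqrt(jnp.maximum(jnp.sum(x**2, axis=-1, keepdims=True), eps))

user_embedding = l2_normalize(jnp.array([[1.0, 0.0]]))       # [B=1, D=2]
candidate_embeddings = l2_normalize(jnp.array([
    [1.0, 0.1],    # nearly aligned with the user
    [0.0, 1.0],    # orthogonal
    [-1.0, 0.0],   # opposite direction
]))                                                          # [N=3, D=2]

scores = user_embedding @ candidate_embeddings.T             # [B, N]
top_k_scores, top_k_indices = jax.lax.top_k(scores, 2)       # [B, K]
# The nearly-aligned candidate (index 0) ranks first.
```
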

User Tower Implementation

The user tower leverages the same transformer architecture used in ranking:
def build_user_representation(
    self,
    batch: RecsysBatch,
    recsys_embeddings: RecsysEmbeddings,
) -> Tuple[jax.Array, jax.Array]:
    """Build user representation from user features and history.

    Uses the Phoenix transformer to encode user + history embeddings
    into a single user representation vector.

    Returns:
        user_representation: L2-normalized user embedding [B, D]
        user_norm: Pre-normalization L2 norm [B, 1]
    """
    # Combine user and history embeddings. The per-source embeddings and
    # padding masks below are unpacked from `batch` and `recsys_embeddings`
    # (unpacking elided in this excerpt).
    embeddings = jnp.concatenate([user_embeddings, history_embeddings], axis=1)
    padding_mask = jnp.concatenate([user_padding_mask, history_padding_mask], axis=1)

    # Pass through transformer
    model_output = self.model(
        embeddings.astype(self.fprop_dtype),
        padding_mask,
        candidate_start_offset=None,
    )

    # Average pool over valid positions
    user_outputs = model_output.embeddings
    mask_float = padding_mask.astype(jnp.float32)[:, :, None]
    user_embeddings_masked = user_outputs * mask_float
    user_embedding_sum = jnp.sum(user_embeddings_masked, axis=1)
    mask_sum = jnp.sum(mask_float, axis=1)
    user_representation = user_embedding_sum / jnp.maximum(mask_sum, 1.0)

    # L2 normalize
    user_norm_sq = jnp.sum(user_representation**2, axis=-1, keepdims=True)
    user_norm = jnp.sqrt(jnp.maximum(user_norm_sq, EPS))
    user_representation = user_representation / user_norm

    return user_representation, user_norm
The user tower uses average pooling over the transformer outputs, weighted by the padding mask. This creates a single vector representation that captures the full user context.
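
The masked average pooling used above can be checked in isolation. A minimal sketch with toy values (shapes and numbers are illustrative):

```python
import jax.numpy as jnp

# Toy transformer outputs: B=2 sequences of length 3 with D=2 features.
outputs = jnp.array([
    [[1.0, 2.0], [3.0, 4.0], [99.0, 99.0]],   # last position is padding
    [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]],
])
padding_mask = jnp.array([[1, 1, 0], [1, 1, 1]])  # 1 = valid position

mask_float = padding_mask.astype(jnp.float32)[:, :, None]
pooled = jnp.sum(outputs * mask_float, axis=1) / jnp.maximum(
    jnp.sum(mask_float, axis=1), 1.0
)
# Row 0 averages only its two valid positions; the padded vector is ignored.
```
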

Candidate Tower Implementation

The candidate tower is a simpler 2-layer MLP that projects combined post and author embeddings:
class CandidateTower(hk.Module):
    """Candidate tower that projects post+author embeddings to a shared embedding space.

    This tower takes the concatenated embeddings of a post and its author,
    and projects them to a normalized representation suitable for similarity search.
    """

    def __call__(self, post_author_embedding: jax.Array) -> jax.Array:
        """Project post+author embeddings to normalized representation.

        Args:
            post_author_embedding: Concatenated post and author embeddings
                Shape: [B, C, num_hashes, D] or [B, num_hashes, D]

        Returns:
            Normalized candidate representation
                Shape: [B, C, D] or [B, D]
        """
        # Flatten the per-hash embeddings into one feature vector:
        # [..., num_hashes, D] -> [..., num_hashes * D]
        *batch_dims, num_hashes, d = post_author_embedding.shape
        post_author_embedding = jnp.reshape(
            post_author_embedding, (*batch_dims, num_hashes * d)
        )

        # Learned projection matrices for the two-layer MLP
        proj_1 = hk.get_parameter(
            "proj_1", [num_hashes * d, 2 * d], init=hk.initializers.VarianceScaling()
        )
        proj_2 = hk.get_parameter(
            "proj_2", [2 * d, d], init=hk.initializers.VarianceScaling()
        )

        # Two-layer MLP with SiLU activation
        hidden = jnp.dot(post_author_embedding, proj_1)  # -> [..., 2*D]
        hidden = jax.nn.silu(hidden)
        candidate_embeddings = jnp.dot(hidden, proj_2)  # -> [..., D]

        # L2 normalize
        candidate_norm_sq = jnp.sum(candidate_embeddings**2, axis=-1, keepdims=True)
        candidate_norm = jnp.sqrt(jnp.maximum(candidate_norm_sq, EPS))
        candidate_representation = candidate_embeddings / candidate_norm

        return candidate_representation
Once both towers produce normalized embeddings, retrieval becomes a simple dot product:
def _retrieve_top_k(
    self,
    user_representation: jax.Array,  # [B, D]
    corpus_embeddings: jax.Array,    # [N, D]
    top_k: int,
    corpus_mask: Optional[jax.Array] = None,
) -> Tuple[jax.Array, jax.Array]:
    """Retrieve top-k candidates from a corpus for each user.

    Returns:
        top_k_indices: [B, K] indices of top-k candidates
        top_k_scores: [B, K] similarity scores of top-k candidates
    """
    # Compute similarity scores
    scores = jnp.matmul(user_representation, corpus_embeddings.T)  # [B, N]

    # Apply corpus mask if provided
    if corpus_mask is not None:
        scores = jnp.where(corpus_mask[None, :], scores, -INF)

    # Select top-k
    top_k_scores, top_k_indices = jax.lax.top_k(scores, top_k)

    return top_k_indices, top_k_scores
Why L2 normalization? Normalizing embeddings to unit length converts cosine similarity into a simple dot product. This enables the use of highly optimized approximate nearest neighbor (ANN) libraries like FAISS or ScaNN for efficient retrieval at scale.
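
The equivalence is easy to verify: for unit-length vectors, the dot product equals cosine similarity. A minimal sketch:

```python
import jax.numpy as jnp

a = jnp.array([3.0, 4.0])
b = jnp.array([1.0, 2.0])

# Cosine similarity of the raw vectors
cosine = jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

# Dot product after L2-normalizing each vector
dot_of_units = jnp.dot(a / jnp.linalg.norm(a), b / jnp.linalg.norm(b))
# The two quantities agree up to float precision.
```
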

Key Design Decisions

Shared Transformer Architecture

The user tower uses the same transformer architecture as the Phoenix ranking model. This provides several benefits:
  • Consistent representations across retrieval and ranking
  • Transfer learning from ranking to retrieval
  • Simplified infrastructure with shared model code

Asymmetric Tower Complexity

User Tower

Heavy: full transformer. Computed once per user request and cached.

Candidate Tower

Light: 2-layer MLP. Pre-computed for all items offline.
This asymmetry is intentional:
  • The user tower can be expensive because it runs once per request
  • The candidate tower must be lightweight because it runs on millions of items

Performance Considerations

Candidate embeddings are pre-computed offline and stored in a vector database. Only the user tower runs at inference time.
In production, exact top-k search is replaced with approximate nearest neighbor algorithms (e.g., FAISS, ScaNN) that provide sub-linear search complexity.
Both towers support batched computation for efficient training and offline candidate encoding.
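
A minimal sketch of this serving pattern, assuming the precomputed corpus fits in device memory (the names and sizes here are illustrative):

```python
import jax
import jax.numpy as jnp

N, D, K = 1000, 64, 10

# Stand-in for corpus embeddings encoded offline by the candidate tower
# (already L2-normalized).
corpus_embeddings = jnp.ones((N, D)) / jnp.sqrt(D)

@jax.jit
def retrieve(user_embeddings):
    """Score a batch of users against the fixed corpus and take top-K."""
    scores = user_embeddings @ corpus_embeddings.T    # [B, N]
    return jax.lax.top_k(scores, K)                   # ([B, K], [B, K])

user_embeddings = jnp.ones((4, D)) / jnp.sqrt(D)      # a batch of 4 users
top_scores, top_indices = retrieve(user_embeddings)
```

Because `corpus_embeddings` is captured as a constant, the jitted function recompiles only if its shape changes; per-request work is just the matmul and top-k.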

Model Configuration

@dataclass
class PhoenixRetrievalModelConfig:
    """Configuration for the Phoenix Retrieval Model.

    This model uses the same transformer architecture as the Phoenix ranker
    for encoding user representations.
    """

    model: TransformerConfig
    emb_size: int
    history_seq_len: int = 128
    candidate_seq_len: int = 32

    hash_config: HashConfig = None
    product_surface_vocab_size: int = 16
    fprop_dtype: Any = jnp.bfloat16
Source: recsys_retrieval_model.py:103-121

Next Steps

Ranking Model

Learn how retrieved candidates are ranked

Architecture Details

Explore transformer implementation details
