
Overview

Phoenix uses a hash-based embedding approach for all categorical features (users, posts, authors). Instead of maintaining a single massive embedding table, the system uses multiple hash functions to map each ID to multiple embedding vectors, which are then combined through learned projections.

Why Hash-based Embeddings?

Traditional embedding approaches face scalability challenges. With billions of users and posts, a single embedding table with one row per ID runs into three problems:
  • Memory: vocab_size × embedding_dim × 4 bytes (e.g., 1B users × 256 dims × 4 bytes ≈ 1 TB)
  • Training: Sparse updates make training inefficient
  • Cold start: New users/posts have no embeddings

Hash-based embeddings instead map IDs to a fixed number of buckets:
  • Fixed memory: num_buckets × embedding_dim × 4 bytes (e.g., 1M buckets × 256 dims × 4 bytes ≈ 1 GB)
  • Automatic coverage: All IDs (including new ones) get embeddings
  • Collision handling: Multiple hash functions reduce collision impact
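
The memory figures above follow from simple arithmetic. The sketch below reproduces them; `table_bytes` is a hypothetical helper written for this illustration, and float32 storage (4 bytes per value) is assumed, as in the examples above.

```python
def table_bytes(num_rows: int, embedding_dim: int, bytes_per_value: int = 4) -> int:
    """Size in bytes of a dense embedding table (float32 by default)."""
    return num_rows * embedding_dim * bytes_per_value


full_table = table_bytes(1_000_000_000, 256)  # one row per user ID
hashed_table = table_bytes(1_000_000, 256)    # one row per hash bucket

print(f"full table:   {full_table / 1e12:.3f} TB")    # 1.024 TB
print(f"hashed table: {hashed_table / 1e9:.3f} GB")   # 1.024 GB
print(f"reduction:    {full_table // hashed_table}x") # 1000x
```

The hashed table's size depends only on the bucket count, so it stays constant as the number of users or posts grows.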

Architecture

Configuration

The system uses multiple hash functions for each feature type:
phoenix/recsys_model.py
@dataclass
class HashConfig:
    """Configuration for hash-based embeddings."""

    num_user_hashes: int = 2      # 2 hash functions for users
    num_item_hashes: int = 2      # 2 hash functions for posts
    num_author_hashes: int = 2    # 2 hash functions for authors
Default: Each feature type uses 2 hash functions, providing redundancy and reducing collision impact.

Input Structure

Features are represented as hash arrays:
phoenix/recsys_model.py
class RecsysBatch(NamedTuple):
    """Input batch for the recommendation model."""

    user_hashes: jax.typing.ArrayLike              # [B, num_user_hashes]
    history_post_hashes: jax.typing.ArrayLike       # [B, S, num_item_hashes]
    history_author_hashes: jax.typing.ArrayLike     # [B, S, num_author_hashes]
    history_actions: jax.typing.ArrayLike           # [B, S, num_actions]
    history_product_surface: jax.typing.ArrayLike   # [B, S]
    candidate_post_hashes: jax.typing.ArrayLike     # [B, C, num_item_hashes]
    candidate_author_hashes: jax.typing.ArrayLike   # [B, C, num_author_hashes]
    candidate_product_surface: jax.typing.ArrayLike # [B, C]
Hash value 0 is reserved for padding. Valid hash values start from 1.
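
As a small illustration of the padding convention, the sketch below pads a short history to a fixed length with all-zero rows and derives a mask from the first hash of each item. `pad_history` is a hypothetical helper, and right-padding is an assumption; the source only states that hash value 0 marks padding.

```python
def pad_history(hashes, seq_len, num_hashes):
    """Right-pad a list of per-item hash lists to seq_len with all-zero rows."""
    padded = [list(row) for row in hashes[:seq_len]]
    padded += [[0] * num_hashes for _ in range(seq_len - len(padded))]
    return padded


history = [[42857, 91234], [1077, 530]]  # two real items, two hashes each
padded = pad_history(history, seq_len=4, num_hashes=2)

# An item is real when its first hash is non-zero.
mask = [row[0] != 0 for row in padded]

print(padded)  # [[42857, 91234], [1077, 530], [0, 0], [0, 0]]
print(mask)    # [True, True, False, False]
```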

Hash Embedding Workflow

Step 1: Hash Lookup

Each hash value is used to look up an embedding vector:
# Example: User with ID 123456789
user_id = 123456789

# Apply 2 hash functions. Bucket 0 is reserved for padding,
# so valid IDs are mapped into [1, num_buckets - 1].
hash_1 = hash_function_1(user_id) % (num_buckets - 1) + 1  # e.g., 42857
hash_2 = hash_function_2(user_id) % (num_buckets - 1) + 1  # e.g., 91234

# Look up embeddings (one table row per hash value)
user_hashes = jnp.array([hash_1, hash_2])       # [2]
user_embeddings = embedding_table[user_hashes]  # [2, D]
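
The snippet above leaves the hash functions unspecified. A minimal runnable sketch follows, assuming salted BLAKE2b from the standard library as a stand-in; `hash_id`, `NUM_BUCKETS`, and `EMB_DIM` are hypothetical names, and the hash functions Phoenix actually uses are not described here.

```python
import hashlib
import random

NUM_BUCKETS = 1_000  # illustrative table size; row 0 is the padding row
EMB_DIM = 4          # tiny embedding size, for readability


def hash_id(entity_id: int, seed: int, num_buckets: int) -> int:
    """Map an ID to a bucket in [1, num_buckets - 1]; bucket 0 stays free for padding."""
    digest = hashlib.blake2b(
        entity_id.to_bytes(8, "little"),
        digest_size=8,
        salt=seed.to_bytes(8, "little"),  # a different salt acts as a different hash function
    ).digest()
    return int.from_bytes(digest, "little") % (num_buckets - 1) + 1


# Randomly initialised table standing in for learned embeddings.
rng = random.Random(0)
embedding_table = [
    [rng.gauss(0.0, 1.0) for _ in range(EMB_DIM)] for _ in range(NUM_BUCKETS)
]

user_hashes = [hash_id(123456789, seed, NUM_BUCKETS) for seed in range(2)]  # [2]
user_embeddings = [embedding_table[h] for h in user_hashes]                 # [2, D]

print(user_hashes)
```

Hashing is deterministic, so the same ID always resolves to the same pair of rows at training and serving time.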

Step 2: Embedding Combination

Multiple hash embeddings are concatenated and projected to the model dimension:
phoenix/recsys_model.py
def block_user_reduce(
    user_hashes: jnp.ndarray,              # [B, num_user_hashes]
    user_embeddings: jnp.ndarray,          # [B, num_user_hashes, D]
    num_user_hashes: int,
    emb_size: int,
    embed_init_scale: float = 1.0,
) -> Tuple[jax.Array, jax.Array]:
    """Combine multiple user hash embeddings into a single user representation."""
    B = user_embeddings.shape[0]
    D = emb_size

    # Flatten hash embeddings: [B, num_user_hashes, D] -> [B, 1, num_user_hashes * D]
    user_embedding = user_embeddings.reshape((B, 1, num_user_hashes * D))

    # Learned projection matrix: [num_user_hashes * D, D]
    embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
    proj_mat_1 = hk.get_parameter(
        "proj_mat_1",
        [num_user_hashes * D, D],
        dtype=jnp.float32,
        init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
    )

    # Project to model dimension: [B, 1, num_user_hashes * D] @ [num_user_hashes * D, D] -> [B, 1, D]
    user_embedding = jnp.dot(user_embedding.astype(proj_mat_1.dtype), proj_mat_1).astype(
        user_embeddings.dtype
    )

    # Hash 0 is reserved for padding
    user_padding_mask = (user_hashes[:, 0] != 0).reshape(B, 1).astype(jnp.bool_)

    return user_embedding, user_padding_mask
The function performs three steps (shapes shown for 2 hash functions and D = 256):

1. Flatten: Concatenate all hash embeddings: [B, 2, 256] → [B, 1, 512]
2. Project: Apply learned linear projection: [B, 1, 512] @ [512, 256] → [B, 1, 256]
3. Mask: Create padding mask from first hash value (0 = padding)
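
The three steps above can be traced with small arrays. This is a sketch in NumPy rather than JAX/Haiku, with tiny illustrative dimensions; `proj_mat` is a random matrix standing in for the learned `proj_mat_1`.

```python
import numpy as np

B, K, D = 3, 2, 4  # batch size, hashes per user, embedding size (illustrative)
rng = np.random.default_rng(0)

user_hashes = np.array([[17, 93], [5, 41], [0, 0]])  # last row is padding
user_embeddings = rng.normal(size=(B, K, D)).astype(np.float32)
proj_mat = rng.normal(size=(K * D, D)).astype(np.float32)  # stand-in for proj_mat_1

# 1. Flatten: [B, K, D] -> [B, 1, K * D]
flat = user_embeddings.reshape(B, 1, K * D)

# 2. Project: [B, 1, K * D] @ [K * D, D] -> [B, 1, D]
projected = flat @ proj_mat

# 3. Mask: [B, 1], False where the first hash is 0 (padding)
mask = (user_hashes[:, 0] != 0).reshape(B, 1)

print(flat.shape, projected.shape)  # (3, 1, 8) (3, 1, 4)
print(mask.ravel())                 # [ True  True False]
```

The reshape places the K embeddings side by side, so each block of D rows in the projection matrix acts on one hash function's embedding.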

Step 3: Feature Combination

For history and candidates, multiple feature types (post, author, actions, product surface) are combined:
phoenix/recsys_model.py
def block_history_reduce(
    history_post_hashes: jnp.ndarray,                    # [B, S, num_item_hashes]
    history_post_embeddings: jnp.ndarray,                # [B, S, num_item_hashes, D]
    history_author_embeddings: jnp.ndarray,              # [B, S, num_author_hashes, D]
    history_product_surface_embeddings: jnp.ndarray,     # [B, S, D]
    history_actions_embeddings: jnp.ndarray,             # [B, S, D]
    num_item_hashes: int,
    num_author_hashes: int,
    embed_init_scale: float = 1.0,
) -> Tuple[jax.Array, jax.Array]:
    """Combine history embeddings (post, author, actions, product_surface) into sequence."""
    B, S, _, D = history_post_embeddings.shape

    # Flatten hash embeddings
    history_post_embeddings_reshaped = history_post_embeddings.reshape(
        (B, S, num_item_hashes * D)
    )
    history_author_embeddings_reshaped = history_author_embeddings.reshape(
        (B, S, num_author_hashes * D)
    )

    # Concatenate all feature types: post + author + actions + product_surface
    post_author_embedding = jnp.concatenate(
        [
            history_post_embeddings_reshaped,    # [B, S, num_item_hashes * D]
            history_author_embeddings_reshaped,  # [B, S, num_author_hashes * D]
            history_actions_embeddings,          # [B, S, D]
            history_product_surface_embeddings,  # [B, S, D]
        ],
        axis=-1,
    )  # [B, S, (num_item_hashes + num_author_hashes + 2) * D]

    # Project to model dimension
    embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
    proj_mat_3 = hk.get_parameter(
        "proj_mat_3",
        [post_author_embedding.shape[-1], D],
        dtype=jnp.float32,
        init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
    )

    history_embedding = jnp.dot(
        post_author_embedding.astype(proj_mat_3.dtype), proj_mat_3
    ).astype(post_author_embedding.dtype)

    history_padding_mask = (history_post_hashes[:, :, 0] != 0).reshape(B, S)

    return history_embedding, history_padding_mask

Example: Processing a History Item

Let’s trace how a single history item is embedded:
# Input
post_id = 987654321
author_id = 123456789
action = "like"
product_surface = "home_timeline"

# Step 1: Hash to buckets
post_hashes = [hash1(post_id), hash2(post_id)]           # [2]
author_hashes = [hash1(author_id), hash2(author_id)]     # [2]

# Step 2: Look up embeddings (pre-computed)
post_embeddings = embedding_table[post_hashes]           # [2, 256]
author_embeddings = embedding_table[author_hashes]       # [2, 256]
action_embedding = action_projection[action]             # [256]
product_embedding = product_embedding_table[product_surface]  # [256]

# Step 3: Concatenate
combined = concat([
    post_embeddings.flatten(),    # [512]
    author_embeddings.flatten(),  # [512]
    action_embedding,             # [256]
    product_embedding             # [256]
])  # [1536]

# Step 4: Project to model dimension
history_item_embedding = combined @ projection_matrix  # [1536] @ [1536, 256] -> [256]
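
The trace above is pseudocode. A runnable shape check might look like the following, assuming NumPy and random values in place of the learned tables and projection matrix; only the shapes are meaningful.

```python
import numpy as np

D = 256
rng = np.random.default_rng(0)

# Random stand-ins for the looked-up embeddings.
post_embeddings = rng.normal(size=(2, D))    # two hash lookups for the post
author_embeddings = rng.normal(size=(2, D))  # two hash lookups for the author
action_embedding = rng.normal(size=(D,))
product_embedding = rng.normal(size=(D,))

# Concatenate: 2*D + 2*D + D + D = 6*D values
combined = np.concatenate([
    post_embeddings.reshape(-1),
    author_embeddings.reshape(-1),
    action_embedding,
    product_embedding,
])

# Project back to the model dimension with a random stand-in matrix.
projection_matrix = rng.normal(size=(combined.shape[0], D))
history_item_embedding = combined @ projection_matrix

print(combined.shape, history_item_embedding.shape)  # (1536,) (256,)
```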

Benefits

1. Memory Efficiency

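Traditional embedding table size = vocabulary_size × embedding_dim × 4 bytes (float32)

Example (one row per ID):
- 1B users × 256 dims × 4 bytes ≈ 1 TB
- 10B posts × 256 dims × 4 bytes ≈ 10 TB
- 100M authors × 256 dims × 4 bytes ≈ 100 GB

Total: ~11 TB

With hashing, each table is num_buckets × embedding_dim × 4 bytes regardless of vocabulary size (e.g., 1M buckets × 256 dims × 4 bytes ≈ 1 GB).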

2. Cold Start Handling

New users and posts automatically get embeddings:
# New user with ID never seen before
new_user_id = 999999999999

# Still gets valid embeddings through hash functions
embedding = get_embedding(new_user_id)  # Works immediately
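
`get_embedding` is not defined in this page. The sketch below shows the idea under stated assumptions: a salted BLAKE2b hash stands in for the real hash functions, and `bucket` and `NUM_BUCKETS` are hypothetical names. An unseen ID still resolves to a valid row, and that row is typically one that previously seen IDs already map to.

```python
import hashlib

NUM_BUCKETS = 1_000  # illustrative; bucket 0 is reserved for padding


def bucket(entity_id: int, seed: int, num_buckets: int) -> int:
    """Map an ID to a bucket in [1, num_buckets - 1]."""
    digest = hashlib.blake2b(
        entity_id.to_bytes(8, "little"),
        digest_size=8,
        salt=seed.to_bytes(8, "little"),
    ).digest()
    return int.from_bytes(digest, "little") % (num_buckets - 1) + 1


# Buckets touched by IDs that appeared during training (hypothetical ID range).
seen_buckets = {bucket(i, 0, NUM_BUCKETS) for i in range(1, 10_001)}

new_user_id = 999_999_999_999  # never seen before
new_bucket = bucket(new_user_id, 0, NUM_BUCKETS)

print(new_bucket)                  # a valid, non-padding row index
print(new_bucket in seen_buckets)  # whether that row was already trained on
```

The embedding a new ID receives is therefore shared with other IDs rather than specific to it; it becomes more informative as the rows it maps to receive gradient updates from its own interactions.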

3. Collision Robustness

Multiple hash functions provide redundancy:
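  • If hash1(user_A) == hash1(user_B), the two users share their first embedding (a collision)
  • It is likely that hash2(user_A) != hash2(user_B), so their second embeddings differ
  • The learned projection combines both embeddings, so the two users still get distinct representations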
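
Under the idealized assumption of independent, uniform hash functions over N buckets, two distinct IDs collide in one hash with probability about 1/N and in both hashes with probability about 1/N². The sketch below counts this empirically, using salted BLAKE2b as a stand-in for the real hash functions; `bucket` and `colliding_pairs` are hypothetical helpers.

```python
import hashlib
from collections import Counter


def bucket(entity_id: int, seed: int, num_buckets: int) -> int:
    """Map an ID to a bucket in [1, num_buckets - 1]."""
    digest = hashlib.blake2b(
        entity_id.to_bytes(8, "little"),
        digest_size=8,
        salt=seed.to_bytes(8, "little"),
    ).digest()
    return int.from_bytes(digest, "little") % (num_buckets - 1) + 1


def colliding_pairs(keys) -> int:
    """Number of unordered ID pairs that share the same key."""
    return sum(n * (n - 1) // 2 for n in Counter(keys).values())


num_ids, num_buckets = 5_000, 1_000
ids = range(1, num_ids + 1)
h1 = [bucket(i, 0, num_buckets) for i in ids]
h2 = [bucket(i, 1, num_buckets) for i in ids]

single = colliding_pairs(h1)         # pairs sharing the first hash
both = colliding_pairs(zip(h1, h2))  # pairs sharing both hashes

print(f"pairs colliding in hash 1:     {single}")
print(f"pairs colliding in both hashes: {both}")
```

Pairs that collide in both hashes receive identical inputs to the projection, so they are the only ones the model cannot tell apart from hash embeddings alone.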

4. Training Efficiency

Smaller embedding tables mean:
  • Faster gradient updates
  • Better cache utilization
  • More uniform training signal (fewer sparse updates)

Trade-offs

Information Loss: Hash collisions mean different entities may share embedding components. The model must learn to disambiguate through the projection layer and context.
Tuning: The number of hash functions and the number of buckets together set a memory-accuracy trade-off: more of either reduces the impact of collisions but increases parameter count. Phoenix defaults to 2 hash functions per feature type.

Implementation in Retrieval

The retrieval model uses the same hash-based embedding approach:
phoenix/recsys_retrieval_model.py
@dataclass
class PhoenixRetrievalModel(hk.Module):
    """A two-tower retrieval model using the Phoenix transformer for user encoding."""
    
    def build_user_representation(
        self,
        batch: RecsysBatch,
        recsys_embeddings: RecsysEmbeddings,
    ) -> Tuple[jax.Array, jax.Array]:
        """Build user representation from user features and history."""
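        # Same hash-based embedding reduction as in the ranking model.
        # (config and hash_config are set up elsewhere in the model and are not shown in this excerpt.)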
        user_embeddings, user_padding_mask = block_user_reduce(
            batch.user_hashes,
            recsys_embeddings.user_embeddings,
            hash_config.num_user_hashes,
            config.emb_size,
            1.0,
        )
        # ... transformer encoding ...

Candidate Isolation

Learn how candidates are scored independently in the ranking transformer

Multi-action Prediction

See how hash embeddings feed into multi-task prediction heads
