
Overview

The PhoenixModel is a transformer-based recommendation model for ranking candidate posts. It processes user features, interaction history, and candidate posts to predict engagement actions.

Classes

PhoenixModel

A transformer-based recommendation model for ranking candidates.
| Parameter | Type | Required / Default | Description |
| --- | --- | --- | --- |
| `model` | `Transformer` | required | The underlying Grok transformer model |
| `config` | `PhoenixModelConfig` | required | Model configuration including embedding size, sequence lengths, and hash config |
| `fprop_dtype` | `Any` | default: `jnp.bfloat16` | Forward-propagation data type for computation |
| `name` | `Optional[str]` | optional | Optional name for the model |

PhoenixModelConfig

Configuration dataclass for the recommendation system model.
| Parameter | Type | Required / Default | Description |
| --- | --- | --- | --- |
| `model` | `TransformerConfig` | required | Transformer architecture configuration |
| `emb_size` | `int` | required | Embedding dimension size |
| `num_actions` | `int` | required | Number of action types to predict (e.g., like, repost, reply) |
| `history_seq_len` | `int` | default: `128` | Maximum length of the user history sequence |
| `candidate_seq_len` | `int` | default: `32` | Maximum number of candidate posts to rank |
| `name` | `Optional[str]` | optional | Optional model name |
| `fprop_dtype` | `Any` | default: `jnp.bfloat16` | Forward-propagation data type |
| `hash_config` | `HashConfig` | | Hash configuration for multi-hash embeddings |
| `product_surface_vocab_size` | `int` | default: `16` | Size of the product surface vocabulary (e.g., ForYou, Following, Search) |
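For illustration, the documented fields can be mirrored as a plain Python dataclass. This is a hypothetical stand-in, not the actual class: the real `PhoenixModelConfig` also carries a `TransformerConfig` and `fprop_dtype`, whose constructors are not documented here, so they are omitted.

```python
from dataclasses import dataclass, field
from typing import Optional

# Hypothetical stand-ins mirroring the documented fields and defaults.
@dataclass
class HashConfigSketch:
    num_user_hashes: int = 2
    num_item_hashes: int = 2
    num_author_hashes: int = 2

@dataclass
class PhoenixModelConfigSketch:
    emb_size: int                 # required
    num_actions: int              # required
    history_seq_len: int = 128
    candidate_seq_len: int = 32
    product_surface_vocab_size: int = 16
    hash_config: HashConfigSketch = field(default_factory=HashConfigSketch)
    name: Optional[str] = None

cfg = PhoenixModelConfigSketch(emb_size=256, num_actions=8)
```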

HashConfig

Configuration for hash-based embeddings.
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| `num_user_hashes` | `int` | `2` | Number of hash functions for user embeddings |
| `num_item_hashes` | `int` | `2` | Number of hash functions for post embeddings |
| `num_author_hashes` | `int` | `2` | Number of hash functions for author embeddings |

Named Tuples

RecsysBatch

Input batch for the recommendation model containing feature data (hashes, actions, product surfaces) but NOT embeddings.
| Field | Type | Shape | Description |
| --- | --- | --- | --- |
| `user_hashes` | `jax.typing.ArrayLike` | `[B, num_user_hashes]` | User hash values |
| `history_post_hashes` | `jax.typing.ArrayLike` | `[B, S, num_item_hashes]` | Post hash values for history |
| `history_author_hashes` | `jax.typing.ArrayLike` | `[B, S, num_author_hashes]` | Author hash values for history |
| `history_actions` | `jax.typing.ArrayLike` | `[B, S, num_actions]` | Multi-hot action vectors for history |
| `history_product_surface` | `jax.typing.ArrayLike` | `[B, S]` | Product surface indices for history |
| `candidate_post_hashes` | `jax.typing.ArrayLike` | `[B, C, num_item_hashes]` | Post hash values for candidates |
| `candidate_author_hashes` | `jax.typing.ArrayLike` | `[B, C, num_author_hashes]` | Author hash values for candidates |
| `candidate_product_surface` | `jax.typing.ArrayLike` | `[B, C]` | Product surface indices for candidates |
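The field shapes above can be sanity-checked with a toy batch. This sketch uses NumPy and a plain dict in place of the real named tuple and `jnp` arrays; the sizes (`B=2`, `S=4`, `C=3`, `num_actions=8`) are made up for illustration. Hash value 0 marks padding, so real (non-padded) hashes are drawn from 1 upward.

```python
import numpy as np

# Toy RecsysBatch-like input: B=batch, S=history length, C=candidates.
B, S, C = 2, 4, 3
num_user_hashes = num_item_hashes = num_author_hashes = 2
num_actions = 8

rng = np.random.default_rng(0)
batch = {
    "user_hashes": rng.integers(1, 1000, (B, num_user_hashes)),
    "history_post_hashes": rng.integers(0, 1000, (B, S, num_item_hashes)),
    "history_author_hashes": rng.integers(0, 1000, (B, S, num_author_hashes)),
    "history_actions": rng.integers(0, 2, (B, S, num_actions)),  # multi-hot
    "history_product_surface": rng.integers(0, 16, (B, S)),
    "candidate_post_hashes": rng.integers(1, 1000, (B, C, num_item_hashes)),
    "candidate_author_hashes": rng.integers(1, 1000, (B, C, num_author_hashes)),
    "candidate_product_surface": rng.integers(0, 16, (B, C)),
}
```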

RecsysEmbeddings

Container for pre-looked-up embeddings from the embedding tables.
| Field | Type | Shape | Description |
| --- | --- | --- | --- |
| `user_embeddings` | `jax.typing.ArrayLike` | `[B, num_user_hashes, D]` | Pre-looked-up user embeddings |
| `history_post_embeddings` | `jax.typing.ArrayLike` | `[B, S, num_item_hashes, D]` | Pre-looked-up post embeddings for history |
| `candidate_post_embeddings` | `jax.typing.ArrayLike` | `[B, C, num_item_hashes, D]` | Pre-looked-up post embeddings for candidates |
| `history_author_embeddings` | `jax.typing.ArrayLike` | `[B, S, num_author_hashes, D]` | Pre-looked-up author embeddings for history |
| `candidate_author_embeddings` | `jax.typing.ArrayLike` | `[B, C, num_author_hashes, D]` | Pre-looked-up author embeddings for candidates |

RecsysModelOutput

Output of the recommendation model.
| Field | Type | Shape | Description |
| --- | --- | --- | --- |
| `logits` | `jax.Array` | `[B, num_candidates, num_actions]` | Predicted logits for each candidate and action |

Methods

__call__

```python
def __call__(
    batch: RecsysBatch,
    recsys_embeddings: RecsysEmbeddings,
) -> RecsysModelOutput
```

Forward pass for ranking candidates.

**Parameters**

| Parameter | Type | Required | Description |
| --- | --- | --- | --- |
| `batch` | `RecsysBatch` | required | Batch containing hashes, actions, and product surfaces |
| `recsys_embeddings` | `RecsysEmbeddings` | required | Pre-looked-up embeddings from the embedding tables |

**Returns**

| Name | Type | Description |
| --- | --- | --- |
| `output` | `RecsysModelOutput` | Model output containing logits for each candidate, shape `[B, num_candidates, num_actions]` |
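Because actions are multi-hot (a single post can be both liked and replied to), one natural reading of the output, an assumption not confirmed by this reference, is that each per-action logit is an independent binary target, so probabilities come from an elementwise sigmoid rather than a softmax over actions:

```python
import numpy as np

# Hypothetical post-processing of RecsysModelOutput.logits, using NumPy
# in place of jnp. A zero logit maps to probability 0.5 under sigmoid.
B, C, num_actions = 2, 3, 4
logits = np.zeros((B, C, num_actions))       # stand-in model output
probs = 1.0 / (1.0 + np.exp(-logits))        # elementwise sigmoid
```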

build_inputs

```python
def build_inputs(
    batch: RecsysBatch,
    recsys_embeddings: RecsysEmbeddings,
) -> Tuple[jax.Array, jax.Array, int]
```

Build input embeddings from the batch and pre-looked-up embeddings.

**Parameters**

| Parameter | Type | Required | Description |
| --- | --- | --- | --- |
| `batch` | `RecsysBatch` | required | Batch containing hashes, actions, and product surfaces |
| `recsys_embeddings` | `RecsysEmbeddings` | required | Pre-looked-up embeddings from the embedding tables |

**Returns**

| Name | Type | Description |
| --- | --- | --- |
| `embeddings` | `jax.Array` | Combined embeddings `[B, 1 + history_len + num_candidates, D]` |
| `padding_mask` | `jax.Array` | Padding mask `[B, 1 + history_len + num_candidates]` |
| `candidate_start_offset` | `int` | Position where candidates start in the sequence |
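The returned shapes imply the sequence layout: one user token, then history, then candidates, so candidates start at offset `1 + history_len`. A NumPy sketch of that assembly (toy sizes, zero embeddings):

```python
import numpy as np

# Sequence layout produced by build_inputs: [user | history | candidates].
B, S, C, D = 2, 4, 3, 8
user_emb = np.zeros((B, 1, D))        # combined user token
history_emb = np.zeros((B, S, D))     # reduced history sequence
candidate_emb = np.zeros((B, C, D))   # reduced candidate sequence

embeddings = np.concatenate([user_emb, history_emb, candidate_emb], axis=1)
candidate_start_offset = 1 + S        # candidates follow user + history
```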

Utility Functions

block_user_reduce

```python
def block_user_reduce(
    user_hashes: jnp.ndarray,
    user_embeddings: jnp.ndarray,
    num_user_hashes: int,
    emb_size: int,
    embed_init_scale: float = 1.0,
) -> Tuple[jax.Array, jax.Array]
```

Combine multiple user hash embeddings into a single user representation.

**Parameters**

| Parameter | Type | Required / Default | Description |
| --- | --- | --- | --- |
| `user_hashes` | `jnp.ndarray` | required | Hash values `[B, num_user_hashes]`, where 0 = invalid/padding |
| `user_embeddings` | `jnp.ndarray` | required | Looked-up embeddings `[B, num_user_hashes, D]` |
| `num_user_hashes` | `int` | required | Number of hash functions used |
| `emb_size` | `int` | required | Embedding dimension D |
| `embed_init_scale` | `float` | default: `1.0` | Initialization scale for the projection matrix |

**Returns**

| Name | Type | Description |
| --- | --- | --- |
| `user_embedding` | `jax.Array` | Combined user embedding `[B, 1, D]` |
| `user_padding_mask` | `jax.Array` | True where the user is valid `[B, 1]` |
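A hedged sketch of what a reduce like this could do, under the assumption (suggested by the projection-matrix parameter) that the per-hash embeddings are flattened and linearly projected back to `D`. The real function learns the projection (initialized with `embed_init_scale`); here it is a fixed stand-in matrix, and the padding mask follows the documented "hash 0 = invalid" convention.

```python
import numpy as np

# Toy multi-hash user reduce: B=2 users, H=2 hashes, D=4 dims.
B, H, D = 2, 2, 4
user_hashes = np.array([[7, 9],
                        [0, 0]])              # second user is all-padding
user_embeddings = np.ones((B, H, D))          # stand-in looked-up embeddings
W = np.full((H * D, D), 0.1)                  # stand-in projection matrix

user_embedding = user_embeddings.reshape(B, H * D) @ W  # [B, D]
user_embedding = user_embedding[:, None, :]             # [B, 1, D]
# Valid iff any hash is nonzero (hash 0 is reserved for padding).
user_padding_mask = (user_hashes != 0).any(axis=-1, keepdims=True)  # [B, 1]
```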

block_history_reduce

```python
def block_history_reduce(
    history_post_hashes: jnp.ndarray,
    history_post_embeddings: jnp.ndarray,
    history_author_embeddings: jnp.ndarray,
    history_product_surface_embeddings: jnp.ndarray,
    history_actions_embeddings: jnp.ndarray,
    num_item_hashes: int,
    num_author_hashes: int,
    embed_init_scale: float = 1.0,
) -> Tuple[jax.Array, jax.Array]
```

Combine history embeddings (post, author, actions, product surface) into a sequence.

**Parameters**

| Parameter | Type | Required / Default | Description |
| --- | --- | --- | --- |
| `history_post_hashes` | `jnp.ndarray` | required | Post hash values `[B, S, num_item_hashes]` |
| `history_post_embeddings` | `jnp.ndarray` | required | Post embeddings `[B, S, num_item_hashes, D]` |
| `history_author_embeddings` | `jnp.ndarray` | required | Author embeddings `[B, S, num_author_hashes, D]` |
| `history_product_surface_embeddings` | `jnp.ndarray` | required | Product surface embeddings `[B, S, D]` |
| `history_actions_embeddings` | `jnp.ndarray` | required | Action embeddings `[B, S, D]` |
| `num_item_hashes` | `int` | required | Number of hash functions for items |
| `num_author_hashes` | `int` | required | Number of hash functions for authors |
| `embed_init_scale` | `float` | default: `1.0` | Initialization scale for the projection |

**Returns**

| Name | Type | Description |
| --- | --- | --- |
| `history_embeddings` | `jax.Array` | Combined history embeddings `[B, S, D]` |
| `history_padding_mask` | `jax.Array` | Valid history positions `[B, S]` |

block_candidate_reduce

```python
def block_candidate_reduce(
    candidate_post_hashes: jnp.ndarray,
    candidate_post_embeddings: jnp.ndarray,
    candidate_author_embeddings: jnp.ndarray,
    candidate_product_surface_embeddings: jnp.ndarray,
    num_item_hashes: int,
    num_author_hashes: int,
    embed_init_scale: float = 1.0,
) -> Tuple[jax.Array, jax.Array]
```

Combine candidate embeddings (post, author, product surface) into a sequence.

**Parameters**

| Parameter | Type | Required / Default | Description |
| --- | --- | --- | --- |
| `candidate_post_hashes` | `jnp.ndarray` | required | Post hash values `[B, C, num_item_hashes]` |
| `candidate_post_embeddings` | `jnp.ndarray` | required | Post embeddings `[B, C, num_item_hashes, D]` |
| `candidate_author_embeddings` | `jnp.ndarray` | required | Author embeddings `[B, C, num_author_hashes, D]` |
| `candidate_product_surface_embeddings` | `jnp.ndarray` | required | Product surface embeddings `[B, C, D]` |
| `num_item_hashes` | `int` | required | Number of hash functions for items |
| `num_author_hashes` | `int` | required | Number of hash functions for authors |
| `embed_init_scale` | `float` | default: `1.0` | Initialization scale for the projection |

**Returns**

| Name | Type | Description |
| --- | --- | --- |
| `candidate_embeddings` | `jax.Array` | Combined candidate embeddings `[B, C, D]` |
| `candidate_padding_mask` | `jax.Array` | Valid candidate positions `[B, C]` |

Notes

  • Hash value 0 is reserved for padding in all hash-based inputs
  • The model uses multi-hash embeddings to reduce collisions in the embedding space
  • Product surfaces represent different UI contexts (e.g., ForYou, Following, Search)
  • Actions are represented as multi-hot vectors to capture multiple engagement signals
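The collision-reduction point above can be made concrete: two IDs may collide under one hash function but rarely under all of them, so the tuple of hash values (and hence the concatenated multi-hash embedding) still distinguishes them. The hash functions below are made up purely for illustration:

```python
# Two toy hash functions standing in for the model's real hashes.
def h1(x):
    return x % 5

def h2(x):
    return (3 * x + 1) % 7

a, b = 3, 8
# a and b collide under h1 alone, but the (h1, h2) pair separates them,
# which is why multiple hashes reduce effective embedding collisions.
single_hash_collision = h1(a) == h1(b)
multi_hash_collision = (h1(a), h2(a)) == (h1(b), h2(b))
```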
