
Overview

Instead of predicting a single “relevance” score, Phoenix predicts probabilities for multiple engagement actions simultaneously. This multi-task learning approach enables the model to capture nuanced user preferences and allows flexible weighting of different engagement types during ranking.

Motivation

A single relevance score cannot capture the complexity of user engagement:

Problem: Single Score

  • Treats all engagement equally
  • Cannot distinguish passive (click) from active (reply) engagement
  • No signal for negative actions (block, mute)
  • Inflexible weighting at serving time

Solution: Multi-action

  • Predicts probability for each action type
  • Captures full spectrum of user behavior
  • Includes negative signals to avoid bad content
  • Flexible combination at serving time

Action Types

Phoenix predicts probabilities for 15 distinct actions, spanning positive, passive, and negative engagement:
Predictions:
├── Positive Actions (high value)
│   ├── P(favorite)         # User likes the post
│   ├── P(reply)            # User replies to the post
│   ├── P(repost)           # User reposts to their followers
│   ├── P(quote)            # User quote-tweets
│   ├── P(share)            # User shares externally
│   └── P(follow_author)    # User follows the author

├── Passive Actions (medium value)
│   ├── P(click)            # User clicks to view details
│   ├── P(profile_click)    # User clicks on author profile
│   ├── P(video_view)       # User watches video content
│   ├── P(photo_expand)     # User expands photo
│   └── P(dwell)            # User dwells on post (time spent)

└── Negative Actions (penalize)
    ├── P(not_interested)   # User marks "not interested"
    ├── P(block_author)     # User blocks the author
    ├── P(mute_author)      # User mutes the author
    └── P(report)           # User reports the post
The exact number of actions is configurable via PhoenixModelConfig.num_actions. The model architecture adapts automatically.
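To make this concrete, here is a minimal sketch of what such a config could look like; apart from `emb_size` and `num_actions`, which appear elsewhere on this page, the shape of `PhoenixModelConfig` below is an assumption, not the actual Phoenix definition:

```python
from dataclasses import dataclass

@dataclass
class PhoenixModelConfig:
    """Illustrative sketch; only emb_size and num_actions are documented here."""
    emb_size: int = 256      # embedding dimension
    num_actions: int = 15    # number of engagement actions to predict

# The output head sizes itself from the config, so changing num_actions
# is enough to add or remove predicted action types.
config = PhoenixModelConfig(num_actions=16)
unembedding_shape = (config.emb_size, config.num_actions)  # (256, 16)
```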

Architecture

Output Layer

The model uses a single unembedding matrix to project candidate embeddings to action logits:
phoenix/recsys_model.py
def _get_unembedding(self) -> jax.Array:
    """Get the unembedding matrix for decoding to logits."""
    config = self.config
    embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
    unembed_mat = hk.get_parameter(
        "unembeddings",
        [config.emb_size, config.num_actions],  # [256, 15]
        dtype=jnp.float32,
        init=embed_init,
    )
    return unembed_mat

Forward Pass

The ranking model produces action logits for each candidate:
phoenix/recsys_model.py
def __call__(
    self,
    batch: RecsysBatch,
    recsys_embeddings: RecsysEmbeddings,
) -> RecsysModelOutput:
    """Forward pass for ranking candidates.
    
    Returns:
        RecsysModelOutput containing logits for each candidate.
        Shape = [B, num_candidates, num_actions]
    """
    embeddings, padding_mask, candidate_start_offset = self.build_inputs(
        batch, recsys_embeddings
    )

    # Transformer with candidate isolation
    model_output = self.model(
        embeddings,
        padding_mask,
        candidate_start_offset=candidate_start_offset,
    )

    out_embeddings = model_output.embeddings
    out_embeddings = layer_norm(out_embeddings)

    # Extract only candidate embeddings (not user/history)
    candidate_embeddings = out_embeddings[:, candidate_start_offset:, :]
    # Shape: [B, num_candidates, emb_size]

    # Project to action logits
    unembeddings = self._get_unembedding()
    logits = jnp.dot(candidate_embeddings.astype(unembeddings.dtype), unembeddings)
    # Shape: [B, num_candidates, num_actions]

    logits = logits.astype(self.fprop_dtype)

    return RecsysModelOutput(logits=logits)
The forward pass proceeds in four steps:

1. Transformer Encoding: process user + history + candidates → [B, seq_len, 256]
2. Extract Candidates: select only candidate positions → [B, num_candidates, 256]
3. Project to Actions: apply unembedding matrix → [B, num_candidates, 15]
4. Return Logits: output logits for each action type (converted to probabilities during training/inference)

Training

Multi-action prediction is trained as a multi-task binary classification problem:

Label Structure

Each training example has a multi-hot label vector:
# Example: User liked and replied to a post
labels = [
    1,  # favorite ✓
    1,  # reply ✓
    0,  # repost
    0,  # quote
    0,  # share
    # ... etc (click is implicitly true but may not be labeled)
]
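A multi-hot label vector can be built from the set of observed actions. This sketch assumes the action ordering from the list above, and that it matches the ordering of the model's output logits:

```python
# Action ordering taken from the action-type list above; assumed to match
# the ordering of the model's output logits.
ACTIONS = [
    "favorite", "reply", "repost", "quote", "share", "follow_author",
    "click", "profile_click", "video_view", "photo_expand", "dwell",
    "not_interested", "block_author", "mute_author", "report",
]

def make_label_vector(observed_actions):
    """Build a multi-hot label vector from the set of actions a user took."""
    return [1 if a in observed_actions else 0 for a in ACTIONS]

labels = make_label_vector({"favorite", "reply"})
# → [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```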

Loss Function

Binary cross-entropy loss for each action:
# Pseudo-code
logits = model(user, history, candidates)   # [B, C, num_actions]
probabilities = sigmoid(logits)             # [B, C, num_actions]

# Per-action binary cross-entropy (each term is a mean over batch and candidates)
loss = 0
for action_idx in range(num_actions):
    loss += binary_cross_entropy(
        predictions=probabilities[:, :, action_idx],
        labels=labels[:, :, action_idx],
    )

# Average across actions
loss = loss / num_actions
Implicit Feedback: Some actions (like clicks) are implicit. If a user liked a post, they must have clicked it, but clicks may not be explicitly labeled.
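In practice the per-action loop is vectorized. A minimal NumPy sketch of the same mean binary cross-entropy, for illustration only (the actual Phoenix training code is not shown on this page):

```python
import numpy as np

def multi_action_bce(logits, labels):
    """Mean binary cross-entropy over batch, candidates, and actions.

    logits, labels: arrays of shape [B, C, num_actions].
    """
    probs = 1.0 / (1.0 + np.exp(-logits))
    probs = np.clip(probs, 1e-7, 1.0 - 1e-7)  # numerical stability
    bce = -(labels * np.log(probs) + (1.0 - labels) * np.log(1.0 - probs))
    return float(bce.mean())

# With zero logits every probability is 0.5, so the loss is -log(0.5):
loss = multi_action_bce(np.zeros((2, 3, 15)), np.ones((2, 3, 15)))
# → ≈ 0.6931
```

Production implementations typically compute the cross-entropy directly from logits (rather than applying an explicit sigmoid) for better numerical stability.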

Serving: Weighted Scoring

At serving time, action probabilities are combined using learned weights to produce a final ranking score:
# Get predictions from model
predictions = model(user, history, candidates)
probabilities = sigmoid(predictions)  # [num_candidates, num_actions]

# Weighted combination
weights = {
    'favorite': 1.0,
    'reply': 3.0,           # Replies valued higher
    'repost': 2.0,
    'quote': 2.5,
    'click': 0.1,           # Low weight for passive action
    'video_view': 0.3,
    'profile_click': 0.2,
    'photo_expand': 0.1,
    'share': 2.0,
    'dwell': 0.5,
    'follow_author': 4.0,   # Very valuable action
    'not_interested': -2.0, # Negative weight
    'block_author': -10.0,  # Strong negative signal
    'mute_author': -5.0,
    'report': -20.0,        # Strongest negative signal
}

# Compute final score for each candidate.
# Note: the ordering of `weights` must match the model's action ordering.
action_order = list(weights.keys())
final_scores = [
    sum(
        weights[action] * probabilities[candidate_idx, action_idx]
        for action_idx, action in enumerate(action_order)
    )
    for candidate_idx in range(num_candidates)
]

# Rank candidate indices by final score (highest first)
ranked = sorted(range(num_candidates), key=lambda i: final_scores[i], reverse=True)
ranked_candidates = [candidates[i] for i in ranked]
Weight Tuning: Weights can be adjusted without retraining the model, enabling rapid experimentation and A/B testing of ranking strategies.
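Because the weighted combination is just a dot product, it can be written in a few lines of NumPy. The numbers below are made up for illustration and use only three of the actions:

```python
import numpy as np

# Hypothetical probabilities for 3 candidates over 3 actions
# (favorite, reply, block_author), in the model's action order.
weights = np.array([1.0, 3.0, -10.0])
probs = np.array([
    [0.40, 0.10, 0.001],   # candidate 0
    [0.10, 0.30, 0.000],   # candidate 1
    [0.50, 0.00, 0.200],   # candidate 2
])

final_scores = probs @ weights        # one weighted sum per candidate
ranking = np.argsort(-final_scores)   # candidate indices, best first
# → ranking = [1, 0, 2]
```

Candidate 2 has the highest P(favorite) but is heavily penalized by its block probability, which is exactly the behavior the negative weights are meant to produce.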

Benefits

1. Nuanced Understanding

The model learns subtle differences in user intent:
# User A: High P(like), Low P(reply)
# → Passive consumer, likes content but doesn't engage deeply

# User B: High P(reply), Medium P(like)
# → Active engager, participates in conversations

# User C: High P(block_author), Low everything else
# → Content is likely spam or offensive

2. Flexible Optimization

Different product goals can use different weights:
# Optimize for high-engagement actions
weights = {
    'reply': 5.0,
    'repost': 3.0,
    'favorite': 2.0,
    'click': 0.5,
}

3. Better Calibration

Predicting specific actions provides better-calibrated probabilities:
  • P(like) = 0.3 is interpretable: “30% chance user will like this”
  • Single relevance score lacks this interpretation

4. Multi-task Learning

Sharing representations across tasks improves generalization:
  • Rare actions (e.g., report) benefit from signal of common actions (e.g., click)
  • Model learns robust features that transfer across action types

Action Embeddings for History

In the input, historical actions are encoded using a learned projection:
phoenix/recsys_model.py
def _get_action_embeddings(
    self,
    actions: jax.Array,  # [B, S, num_actions] multi-hot
) -> jax.Array:
    """Convert multi-hot action vectors to embeddings.

    Uses a learned projection matrix to map the signed action vector
    to the embedding dimension. This works for any number of actions.
    """
    config = self.config
    _, _, num_actions = actions.shape
    D = config.emb_size

    embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
    action_projection = hk.get_parameter(
        "action_projection",
        [num_actions, D],       # [15, 256]
        dtype=jnp.float32,
        init=embed_init,
    )

    # Convert binary {0,1} to signed {-1,+1}
    actions_signed = (2 * actions - 1).astype(jnp.float32)

    # Project: [B, S, num_actions] @ [num_actions, D] -> [B, S, D]
    action_emb = jnp.dot(actions_signed.astype(action_projection.dtype), action_projection)

    # Mask out invalid positions (where no action occurred)
    valid_mask = jnp.any(actions, axis=-1, keepdims=True)
    action_emb = action_emb * valid_mask

    return action_emb.astype(self.fprop_dtype)
Signed Encoding: Actions are converted from {0, 1} to {-1, +1} before projection. This allows the model to learn both positive and negative directions for each action type.
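The effect of the signed encoding and the validity mask can be seen in a small NumPy example (illustrative; the random matrix stands in for the learned projection):

```python
import numpy as np

rng = np.random.default_rng(0)
num_actions, D = 15, 8
projection = rng.normal(size=(num_actions, D))  # stand-in for the learned matrix

actions = np.zeros((2, num_actions))
actions[0, 0] = 1.0   # position 0: 'favorite' only; position 1: no actions

signed = 2.0 * actions - 1.0          # {0,1} -> {-1,+1}
emb = signed @ projection             # [2, D]

valid = np.any(actions, axis=-1, keepdims=True)
emb = emb * valid                     # positions with no actions become zero
```

Without the mask, an all-zero action row would project to minus the sum of all projection rows instead of a zero vector, which is why the validity mask is applied.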

Example: Scoring a Candidate

Let’s trace how a candidate post is scored:
# Input
user_id = 123
history = [
    {post_id: 1, author_id: 10, actions: ['like', 'click']},
    {post_id: 2, author_id: 11, actions: ['reply', 'like', 'click']},
    {post_id: 3, author_id: 12, actions: ['repost', 'like']},
]
candidate = {post_id: 999, author_id: 50}

# Step 1: Model prediction
logits = model(user_id, history, candidate)  # [15]
probabilities = sigmoid(logits)              # [15]

# Step 2: Example probabilities
P = {
    'favorite': 0.35,
    'reply': 0.05,
    'repost': 0.02,
    'quote': 0.01,
    'click': 0.60,
    'profile_click': 0.10,
    'video_view': 0.00,  # Not a video post
    'photo_expand': 0.20,
    'share': 0.01,
    'dwell': 0.40,
    'follow_author': 0.03,
    'not_interested': 0.01,
    'block_author': 0.001,
    'mute_author': 0.001,
    'report': 0.0001,
}

# Step 3: Apply weights
weights = {
    'favorite': 1.0, 'reply': 3.0, 'repost': 2.0, 'quote': 2.5,
    'click': 0.1, 'profile_click': 0.2, 'video_view': 0.3,
    'photo_expand': 0.1, 'share': 2.0, 'dwell': 0.5, 'follow_author': 4.0,
    'not_interested': -2.0, 'block_author': -10.0,
    'mute_author': -5.0, 'report': -20.0,
}

final_score = sum(weights[action] * P[action] for action in weights.keys())
# = 1.0*0.35 + 3.0*0.05 + 2.0*0.02 + ... + (-20.0)*0.0001
# = 0.35 + 0.15 + 0.04 + ... - 0.002
# ≈ 0.97
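The weighted sum can be checked directly (same probabilities and weights as in steps 2 and 3):

```python
P = {
    'favorite': 0.35, 'reply': 0.05, 'repost': 0.02, 'quote': 0.01,
    'click': 0.60, 'profile_click': 0.10, 'video_view': 0.00,
    'photo_expand': 0.20, 'share': 0.01, 'dwell': 0.40,
    'follow_author': 0.03, 'not_interested': 0.01,
    'block_author': 0.001, 'mute_author': 0.001, 'report': 0.0001,
}
weights = {
    'favorite': 1.0, 'reply': 3.0, 'repost': 2.0, 'quote': 2.5,
    'click': 0.1, 'profile_click': 0.2, 'video_view': 0.3,
    'photo_expand': 0.1, 'share': 2.0, 'dwell': 0.5, 'follow_author': 4.0,
    'not_interested': -2.0, 'block_author': -10.0,
    'mute_author': -5.0, 'report': -20.0,
}

final_score = sum(weights[a] * P[a] for a in weights)
# → ≈ 0.968
```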

Comparison with Single-Score Models

Advantages:
  • Captures nuanced user behavior
  • Flexible weighting at serving time
  • Better calibration for individual actions
  • Negative signals to avoid bad content
  • Multi-task learning improves generalization
Trade-offs:
  • More complex to tune (multiple weights)
  • Higher inference cost (15 predictions vs 1)
  • Requires multi-hot labels in training data

Candidate Isolation

Learn how candidates are scored independently for consistent predictions

Design Decisions

Understand the rationale behind multi-action prediction and other key choices
