
dpo_loss

def dpo_loss(
    chosen_logprobs: Tensor,
    rejected_logprobs: Tensor,
    beta: float = 0.1,
) -> Tensor
Compute Direct Preference Optimization (DPO) loss for preference-based alignment.

Mathematical definition

The DPO loss implements the objective from Rafailov et al. (2023):

$$\mathcal{L}_{\text{DPO}} = -\log \sigma\left(\beta \cdot \Delta \log \pi\right)$$

where:

$$\Delta \log \pi = \log \pi_\theta(y^+ \mid x) - \log \pi_\theta(y^- \mid x)$$

Here:
  • $\pi_\theta$ is the policy (language model) being trained
  • $y^+$ is the preferred/chosen response
  • $y^-$ is the rejected response
  • $x$ is the input prompt
  • $\beta > 0$ is a temperature parameter controlling the strength of the preference
  • $\sigma$ is the sigmoid function

Full DPO objective

The complete DPO objective from the paper includes a reference policy $\pi_{\text{ref}}$:

$$\mathcal{L}_{\text{DPO}} = -\log \sigma\left(\beta \left(\log \frac{\pi_\theta(y^+ \mid x)}{\pi_{\text{ref}}(y^+ \mid x)} - \log \frac{\pi_\theta(y^- \mid x)}{\pi_{\text{ref}}(y^- \mid x)}\right)\right)$$

This implementation uses a simplified single-model variant that omits the reference-policy subtraction while keeping the $\beta$ temperature parameter. This still encourages higher log-probability for preferred responses.
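The simplified single-model variant can be sketched in a few lines of PyTorch. This is an illustrative reimplementation (the name `dpo_loss_sketch` is hypothetical, not the library function); it uses `F.logsigmoid` rather than `torch.log(torch.sigmoid(...))` for numerical stability at large negative margins:

```python
import torch
import torch.nn.functional as F
from torch import Tensor

def dpo_loss_sketch(
    chosen_logprobs: Tensor,
    rejected_logprobs: Tensor,
    beta: float = 0.1,
) -> Tensor:
    """Simplified single-model DPO: -log sigmoid(beta * (logp_chosen - logp_rejected))."""
    if chosen_logprobs.shape != rejected_logprobs.shape:
        raise ValueError("chosen_logprobs and rejected_logprobs must have the same shape")
    if beta <= 0:
        raise ValueError("beta must be positive")
    # Preference margin: positive when the chosen response is more likely
    margin = chosen_logprobs - rejected_logprobs
    # logsigmoid is numerically stable; mean reduces over the batch to a scalar
    return -F.logsigmoid(beta * margin).mean()
```

With a margin of 1.0 and `beta=0.1`, the loss evaluates to $\log(1 + e^{-0.1}) \approx 0.6444$, matching the formula above.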

Parameters

chosen_logprobs
Tensor
required
Log probabilities of chosen/preferred responses. Shape: (batch_size,) or (batch_size, seq_len). These are typically computed as:
chosen_logprobs = log_probs.gather(-1, chosen_ids.unsqueeze(-1)).squeeze(-1).sum(-1)
rejected_logprobs
Tensor
required
Log probabilities of rejected responses. Must have the same shape as chosen_logprobs.
beta
float
default:"0.1"
Temperature parameter controlling preference strength. Higher values increase the penalty for preferring rejected responses. Typical values:
  • 0.1 - Default, balanced preference strength
  • 0.5 - Stronger preference signal
  • 0.01 - Weaker preference signal
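The effect of these values can be seen directly by holding the preference margin fixed and varying `beta`. A minimal sketch (recomputing the loss formula inline, rather than calling the library function): with a negative margin, i.e. the model currently prefers the rejected response, a larger `beta` produces a larger penalty:

```python
import torch
import torch.nn.functional as F

# Fixed negative margin: the model currently prefers the rejected response
margin = torch.tensor([-1.0])

losses = {}
for beta in (0.01, 0.1, 0.5):
    # Same formula as the loss definition: -log sigmoid(beta * margin)
    losses[beta] = (-F.logsigmoid(beta * margin)).mean().item()
    print(f"beta={beta}: loss={losses[beta]:.4f}")
```

Larger `beta` steepens the penalty for mis-ordered pairs, which is why it is described as controlling preference strength.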

Returns

loss
Tensor
Scalar loss tensor suitable for backpropagation via loss.backward(). The loss is the mean over the batch.

Raises

  • ValueError - If chosen_logprobs and rejected_logprobs have different shapes
  • ValueError - If beta <= 0

Properties

Preconditions:
  • chosen_logprobs.shape == rejected_logprobs.shape
  • beta > 0
Postconditions:
  • Returns a scalar tensor suitable for backpropagation
  • No gradients flow through beta (scalar hyperparameter)
Complexity:
  • Time: O(N) where N is the batch size
  • Space: O(N) for intermediate tensors

Usage

import torch
from modern_llm.alignment.dpo_loss import dpo_loss

# Example: Compute DPO loss for a batch of preferences
batch_size = 4

# Log probabilities from model forward pass
chosen_logprobs = torch.tensor([-2.1, -1.8, -2.5, -1.9])  # Higher is better
rejected_logprobs = torch.tensor([-3.2, -2.9, -3.1, -2.8])  # Lower quality

# Compute loss
loss = dpo_loss(
    chosen_logprobs=chosen_logprobs,
    rejected_logprobs=rejected_logprobs,
    beta=0.1,
)

print(f"DPO Loss: {loss.item():.4f}")

# Backpropagate
loss.backward()

Complete training example

import torch
import torch.nn.functional as F
from modern_llm.alignment.dpo_loss import dpo_loss
from modern_llm.models.transformer import ModernDecoderLM

def compute_sequence_logprobs(
    model: ModernDecoderLM,
    input_ids: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """Compute log probabilities for a sequence."""
    outputs = model(input_ids)
    logits = outputs["logits"] if isinstance(outputs, dict) else outputs
    
    # Get log probabilities
    log_probs = F.log_softmax(logits, dim=-1)
    
    # Gather log probs for actual tokens
    batch_size, seq_len = labels.shape
    token_log_probs = log_probs.gather(
        dim=-1,
        index=labels.unsqueeze(-1),
    ).squeeze(-1)
    
    # Sum over sequence (average could also be used)
    return token_log_probs.sum(dim=-1)

# Training loop (assumes `config` and `dataloader` are defined elsewhere)
model = ModernDecoderLM(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

for batch in dataloader:
    # Batch contains prompts, chosen responses, and rejected responses
    chosen_ids = batch["chosen_input_ids"]
    chosen_labels = batch["chosen_labels"]
    rejected_ids = batch["rejected_input_ids"]
    rejected_labels = batch["rejected_labels"]
    
    # Compute log probabilities
    chosen_logprobs = compute_sequence_logprobs(model, chosen_ids, chosen_labels)
    rejected_logprobs = compute_sequence_logprobs(model, rejected_ids, rejected_labels)
    
    # Compute DPO loss
    loss = dpo_loss(
        chosen_logprobs=chosen_logprobs,
        rejected_logprobs=rejected_logprobs,
        beta=0.1,
    )
    
    # Optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print(f"Loss: {loss.item():.4f}")

Gradient behavior

The DPO loss gradient encourages the model to:
  1. Increase the log probability of chosen responses
  2. Decrease the log probability of rejected responses
The gradient magnitude is controlled by $\beta$ and the current preference margin. When the model already strongly prefers chosen over rejected responses (large positive margin), the gradient becomes small, naturally stopping optimization for that example.

References

  • Rafailov, R., Sharma, A., Mitchell, E., Ermon, S., Manning, C. D., & Finn, C. (2023). Direct Preference Optimization: Your Language Model is Secretly a Reward Model. arXiv:2305.18290.
