dpo_loss
def dpo_loss(
chosen_logprobs: Tensor,
rejected_logprobs: Tensor,
beta: float = 0.1,
) -> Tensor
Compute Direct Preference Optimization (DPO) loss for preference-based alignment.
Mathematical definition
The DPO loss implements the objective from Rafailov et al. (2023):
L_DPO = −log σ(β · Δlog π)
where:
Δlog π = log πθ(y+ ∣ x) − log πθ(y− ∣ x)
Here:
- πθ is the policy (language model) being trained
- y+ is the preferred/chosen response
- y− is the rejected response
- x is the input prompt
- β>0 is a temperature parameter controlling the strength of the preference
- σ is the sigmoid function
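As a concrete check of these definitions, the loss for a single preference pair can be evaluated by hand. The log-probability values below are illustrative, not taken from any model:

```python
import math

# Illustrative sequence-level log-probabilities (hypothetical values)
log_p_chosen = -2.1    # log πθ(y+ | x)
log_p_rejected = -3.2  # log πθ(y− | x)
beta = 0.1

# Δlog π = log πθ(y+ | x) − log πθ(y− | x)
delta = log_p_chosen - log_p_rejected  # ≈ 1.1

# L_DPO = −log σ(β · Δlog π), computed as log(1 + e^{−z}) for stability
z = beta * delta
loss = math.log1p(math.exp(-z))

print(f"{loss:.4f}")
```

Because the margin is positive (the model already prefers the chosen response), the loss is below log 2 ≈ 0.693, the value at a zero margin.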
Full DPO objective
The complete DPO objective from the paper includes a reference policy πref:
L_DPO = −log σ(β · (log(πθ(y+ ∣ x) / πref(y+ ∣ x)) − log(πθ(y− ∣ x) / πref(y− ∣ x))))
This implementation uses a simplified single-model variant that omits the reference policy subtraction while keeping the β temperature parameter. This still encourages higher log-probability for preferred responses.
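A minimal sketch of this simplified variant, assuming the documented behavior (shape/β validation, batch-mean reduction); the actual implementation may differ in details:

```python
import torch
import torch.nn.functional as F
from torch import Tensor

def dpo_loss_sketch(
    chosen_logprobs: Tensor,
    rejected_logprobs: Tensor,
    beta: float = 0.1,
) -> Tensor:
    """Simplified DPO loss: −log σ(β · Δlog π), averaged over the batch."""
    if chosen_logprobs.shape != rejected_logprobs.shape:
        raise ValueError("chosen_logprobs and rejected_logprobs must have the same shape")
    if beta <= 0:
        raise ValueError("beta must be positive")
    margin = chosen_logprobs - rejected_logprobs
    # logsigmoid is numerically stabler than log(sigmoid(...))
    return -F.logsigmoid(beta * margin).mean()

chosen = torch.tensor([-2.1, -1.8, -2.5, -1.9])
rejected = torch.tensor([-3.2, -2.9, -3.1, -2.8])
loss = dpo_loss_sketch(chosen, rejected)
print(f"{loss.item():.4f}")
```

Using `F.logsigmoid` rather than `torch.log(torch.sigmoid(...))` avoids underflow when β · margin is a large negative number.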
Parameters
chosen_logprobs (Tensor)
Log probabilities of chosen/preferred responses. Shape: (batch_size,) or (batch_size, seq_len). These are typically computed as:
chosen_logprobs = log_probs.gather(-1, chosen_ids.unsqueeze(-1)).squeeze(-1).sum(-1)
rejected_logprobs (Tensor)
Log probabilities of rejected responses. Must have the same shape as chosen_logprobs.
beta (float, default 0.1)
Temperature parameter controlling preference strength. Higher values increase the penalty for preferring rejected responses. Typical values:
0.1 - Default, balanced preference strength
0.5 - Stronger preference signal
0.01 - Weaker preference signal
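The effect of β can be seen numerically with the simplified loss form at a fixed margin (the margin values here are illustrative):

```python
import math

def simplified_dpo_loss(margin: float, beta: float) -> float:
    """−log σ(β · margin) for a single example, in the stable softplus form."""
    return math.log1p(math.exp(-beta * margin))

# A positive margin means the model already prefers the chosen response;
# a negative margin means it currently prefers the rejected one.
for margin in (1.0, -1.0):
    for beta in (0.01, 0.1, 0.5):
        print(f"margin={margin:+.1f} beta={beta:<4} loss={simplified_dpo_loss(margin, beta):.4f}")
```

With a larger β, correctly ranked pairs incur less loss and incorrectly ranked pairs incur more, i.e. the preference signal is sharper.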
Returns
Scalar loss tensor suitable for backpropagation via loss.backward(). The loss is the mean over the batch.
Raises
ValueError - If chosen_logprobs and rejected_logprobs have different shapes
ValueError - If beta <= 0
Properties
Preconditions:
chosen_logprobs.shape == rejected_logprobs.shape
beta > 0
Postconditions:
- Returns a scalar tensor suitable for backpropagation
- No gradients flow through beta (a scalar hyperparameter)
Complexity:
- Time: O(N) where N is the batch size
- Space: O(N) for intermediate tensors
Usage
import torch
from modern_llm.alignment.dpo_loss import dpo_loss
# Example: Compute DPO loss for a batch of preferences
batch_size = 4
# Log probabilities from model forward pass
chosen_logprobs = torch.tensor([-2.1, -1.8, -2.5, -1.9]) # Higher is better
rejected_logprobs = torch.tensor([-3.2, -2.9, -3.1, -2.8]) # Lower quality
# Compute loss
loss = dpo_loss(
chosen_logprobs=chosen_logprobs,
rejected_logprobs=rejected_logprobs,
beta=0.1,
)
print(f"DPO Loss: {loss.item():.4f}")
# Backpropagate
loss.backward()
Complete training example
import torch
import torch.nn.functional as F
from modern_llm.alignment.dpo_loss import dpo_loss
from modern_llm.models.transformer import ModernDecoderLM
def compute_sequence_logprobs(
    model: ModernDecoderLM,
    input_ids: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """Compute log probabilities for a sequence."""
    outputs = model(input_ids)
    logits = outputs["logits"] if isinstance(outputs, dict) else outputs
    # Get log probabilities
    log_probs = F.log_softmax(logits, dim=-1)
    # Gather log probs for actual tokens
    token_log_probs = log_probs.gather(
        dim=-1,
        index=labels.unsqueeze(-1),
    ).squeeze(-1)
    # Sum over sequence (averaging could also be used)
    return token_log_probs.sum(dim=-1)
# Training loop
model = ModernDecoderLM(config)  # config: model configuration defined elsewhere
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
for batch in dataloader:
    # Batch contains prompts, chosen responses, and rejected responses
    chosen_ids = batch["chosen_input_ids"]
    chosen_labels = batch["chosen_labels"]
    rejected_ids = batch["rejected_input_ids"]
    rejected_labels = batch["rejected_labels"]
    # Compute log probabilities
    chosen_logprobs = compute_sequence_logprobs(model, chosen_ids, chosen_labels)
    rejected_logprobs = compute_sequence_logprobs(model, rejected_ids, rejected_labels)
    # Compute DPO loss
    loss = dpo_loss(
        chosen_logprobs=chosen_logprobs,
        rejected_logprobs=rejected_logprobs,
        beta=0.1,
    )
    # Optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Loss: {loss.item():.4f}")
Gradient behavior
The DPO loss gradient encourages the model to:
- Increase log probability of chosen responses
- Decrease log probability of rejected responses
The gradient magnitude is controlled by β and the current preference margin.
When the model already strongly prefers chosen over rejected responses (large positive margin), the gradient becomes small, naturally stopping optimization for that example.
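This saturation can be checked directly with autograd, using the same simplified loss form as above (a sketch, not the library's own code):

```python
import torch
import torch.nn.functional as F

BETA = 0.1

def grad_wrt_chosen(margin_value: float) -> float:
    """Gradient of −log σ(β · margin) with respect to the chosen log-probability."""
    chosen = torch.tensor(margin_value, requires_grad=True)
    rejected = torch.tensor(0.0)
    loss = -F.logsigmoid(BETA * (chosen - rejected))
    loss.backward()
    return chosen.grad.item()

small_margin_grad = grad_wrt_chosen(0.0)    # no preference yet: gradient is −β/2
large_margin_grad = grad_wrt_chosen(50.0)   # model already strongly prefers chosen
print(small_margin_grad, large_margin_grad)
```

Analytically the gradient is −β · σ(−β · margin), so it approaches zero as the margin grows and approaches −β when the model strongly prefers the rejected response.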
References
- Rafailov, R., Sharma, A., Mitchell, E., Ermon, S., Manning, C. D., & Finn, C. (2023). Direct Preference Optimization: Your Language Model is Secretly a Reward Model. arXiv:2305.18290.