Overview
The FocalLoss class implements the focal loss function, which is particularly effective for training models on imbalanced datasets. Focal loss modifies standard cross-entropy by adding a modulating factor that down-weights easy, well-classified examples and concentrates training on hard, misclassified ones. This makes it well suited to AD conversion prediction, where converters are a minority class.
Class Definition
FocalLoss
```python
class FocalLoss(nn.Module):
    def __init__(
        self,
        alpha: float = 0.25,
        gamma: float = 2.0,
        reduction: str = 'mean',
        label_smoothing: float = 0.0
    )
```
PyTorch loss module for Focal Loss with optional label smoothing.

Parameters:
- `alpha`: Weight for the positive class (minority class). Values closer to 1.0 give more weight to converters. Typical range: 0.25-0.90.
- `gamma`: Focusing parameter that controls how much to down-weight easy examples. Higher values focus more on hard examples. Typical range: 2.0-5.0.
- `reduction`: Specifies the reduction to apply to the output. Options: 'mean', 'sum', or 'none'.
- `label_smoothing`: Label smoothing factor to prevent overconfidence. Typical range: 0.0-0.1.
forward
```python
def forward(
    self,
    inputs: torch.Tensor,
    targets: torch.Tensor
) -> torch.Tensor
```
Computes the focal loss between predictions and targets.

Parameters:
- `inputs`: Predicted logits from the model with shape (batch_size, num_classes).
- `targets`: Ground truth class indices with shape (batch_size,).

Returns: Computed focal loss value (a scalar if reduction is 'mean' or 'sum').
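The internals of `forward` are not shown in this document. As a rough pure-Python sketch of the per-sample computation (an assumption about the implementation: label smoothing and reduction are ignored, and class 1 is treated as the positive class):

```python
import math

def focal_loss_per_sample(logits, target, alpha=0.25, gamma=2.0):
    """Sketch of the focal loss for one row of logits (hypothetical helper,
    not the actual FocalLoss internals).

    logits: list of floats, one per class; target: int class index.
    alpha weights class 1 (the positive class); 1 - alpha weights class 0.
    """
    # Numerically stable softmax to get class probabilities
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    pt = exps[target] / sum(exps)                  # probability of the true class
    alpha_t = alpha if target == 1 else 1 - alpha  # class-dependent weight
    return -alpha_t * (1 - pt) ** gamma * math.log(pt)

# A confidently correct prediction (pt ≈ 0.982) contributes far less
# loss than an uncertain one (pt ≈ 0.269)
easy = focal_loss_per_sample([-2.0, 2.0], 1)
hard = focal_loss_per_sample([0.5, -0.5], 1)
```

The class provides this same behavior in batched, differentiable tensor form.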
Usage Examples
Basic Usage
```python
from FocalLoss import FocalLoss
import torch

# Initialize focal loss
criterion = FocalLoss(
    alpha=0.90,           # High weight for minority class
    gamma=3.0,            # Strong focus on hard examples
    label_smoothing=0.05
)

# Model predictions (logits)
logits = model(data)

# Compute loss
loss = criterion(logits, labels)
loss.backward()
```
Training Loop Integration
```python
# Configuration from main.py
criterion = FocalLoss(
    alpha=opt.focal_alpha,               # e.g., 0.90
    gamma=opt.focal_gamma,               # e.g., 3.0
    label_smoothing=opt.label_smoothing  # e.g., 0.05
)

for epoch in range(num_epochs):
    for batch in train_loader:
        optimizer.zero_grad()

        # Forward pass
        logits = classifier(graph_seq, None, lengths)

        # Compute focal loss
        loss = criterion(logits, labels)

        # Backward pass
        loss.backward()
        optimizer.step()
```
Parameter Tuning Guidelines
```python
# For heavily imbalanced datasets (10:1 or more)
criterion_high_imbalance = FocalLoss(
    alpha=0.90,          # Very high weight on minority class
    gamma=3.0,           # Strong focusing
    label_smoothing=0.05
)

# For moderately imbalanced datasets (3:1 to 10:1)
criterion_moderate = FocalLoss(
    alpha=0.75,
    gamma=2.0,
    label_smoothing=0.0
)

# For nearly balanced datasets
criterion_balanced = FocalLoss(
    alpha=0.5,           # Equal weight to both classes
    gamma=2.0,
    label_smoothing=0.0
)
```
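The tiers above are starting points rather than fixed rules. One common heuristic (an assumption here, not prescribed by this codebase) is to set alpha near the majority-class fraction, so the rarer positive class receives proportionally more weight:

```python
def suggest_alpha(n_negative, n_positive):
    """Heuristic starting point for alpha (hypothetical helper, not part of
    FocalLoss): weight the positive class by the fraction of negatives."""
    return n_negative / (n_negative + n_positive)

# A 10:1 imbalance suggests alpha ≈ 0.91, close to the 0.90 tier above
alpha_suggestion = suggest_alpha(1000, 100)
```

Treat the result only as an initial value to refine against validation performance.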
Implementation Details
Loss Computation
The focal loss is computed as:

FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)

Where:
- p_t is the model's estimated probability for the correct class
- α_t is the class-dependent weighting factor
- γ is the focusing parameter
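A quick worked example makes the modulating factor concrete. With γ = 2 (α omitted for clarity), an easy example at p_t = 0.9 is down-weighted about 100x relative to plain cross-entropy, while a hard example at p_t = 0.3 is barely reduced:

```python
import math

gamma = 2.0

ce_easy = -math.log(0.9)                # plain cross-entropy, pt = 0.9
fl_easy = (1 - 0.9) ** gamma * ce_easy  # modulating factor 0.01: ~100x smaller

ce_hard = -math.log(0.3)                # pt = 0.3, a hard example
fl_hard = (1 - 0.3) ** gamma * ce_hard  # modulating factor 0.49: barely reduced
```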
Label Smoothing
When label_smoothing > 0, hard labels are smoothed:
target_smooth = target * (1 - ε) + ε / num_classes

where ε is the label_smoothing value. This prevents the model from becoming overconfident.
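For example, with two classes and ε = 0.05, the hard target [0, 1] becomes [0.025, 0.975]. A minimal sketch of the rule (`smooth_labels` is a hypothetical helper, not part of the class):

```python
def smooth_labels(one_hot, epsilon):
    """Apply target * (1 - ε) + ε / num_classes to a one-hot target vector."""
    k = len(one_hot)
    return [t * (1 - epsilon) + epsilon / k for t in one_hot]

smoothed = smooth_labels([0.0, 1.0], 0.05)  # → [0.025, 0.975]
```

The smoothed targets still sum to 1, so they remain a valid probability distribution.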
Configuration in main.py
```python
parser.add_argument('--focal_alpha', type=float, default=0.90,
                    help='focal loss alpha (weight for minority class)')
parser.add_argument('--focal_gamma', type=float, default=3.0,
                    help='focal loss gamma (focusing parameter)')
parser.add_argument('--label_smoothing', type=float, default=0.05,
                    help='label smoothing factor')
```
Best Practices
- For AD conversion prediction, start with alpha=0.90 and gamma=3.0 to focus heavily on the minority converter class.
- Very high gamma values (>5.0) may cause training instability. Monitor validation performance carefully.
- Label smoothing helps prevent overfitting but may slightly reduce peak performance. Start with 0.0 and add 0.05 if you observe overconfidence.
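To see why large γ can destabilize training, compare the modulating factor for a typical easy example (p_t = 0.95) at γ = 2 versus γ = 5 (a small illustrative check, not from the source):

```python
pt = 0.95                  # a typical well-classified example
factor_g2 = (1 - pt) ** 2  # ≈ 2.5e-3
factor_g5 = (1 - pt) ** 5  # ≈ 3.1e-7, roughly 8000x more suppressed
# With gamma = 5 the gradient signal from easy examples all but vanishes,
# so each batch update is dominated by a few hard (possibly noisy) samples.
```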