Overview

The FocalLoss class implements the Focal Loss function, which is particularly effective for training models on imbalanced datasets. It modifies the standard cross-entropy loss with a modulating factor that reduces the contribution of easy examples and concentrates training on hard, misclassified ones. This makes it well suited to AD conversion prediction, where converters are a minority class.
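The mechanism described above can be sketched as a small PyTorch module. This is an illustrative implementation, not the repository's actual code: the class name `FocalLossSketch` is hypothetical, and it assumes binary classification with logits of shape (batch_size, 2), hard labels, and mean reduction.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLossSketch(nn.Module):
    """Minimal focal loss sketch for binary classification (hypothetical)."""

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha  # weight for the positive (minority) class
        self.gamma = gamma  # focusing parameter

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Per-example cross-entropy: -log(p_t)
        ce = F.cross_entropy(inputs, targets, reduction='none')
        # Recover p_t, the probability assigned to the true class
        pt = torch.exp(-ce)
        # alpha_t: alpha for class 1 (positives), 1 - alpha for class 0
        t = targets.float()
        alpha_t = self.alpha * t + (1.0 - self.alpha) * (1.0 - t)
        # Modulating factor (1 - p_t)^gamma down-weights easy examples
        return (alpha_t * (1.0 - pt) ** self.gamma * ce).mean()
```

With gamma set to 0 and alpha to 0.5, this reduces to (half of) plain cross-entropy, which is a convenient sanity check when tuning.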

Class Definition

FocalLoss

class FocalLoss(nn.Module):
    def __init__(
        self,
        alpha: float = 0.25,
        gamma: float = 2.0,
        reduction: str = 'mean',
        label_smoothing: float = 0.0
    )
PyTorch loss module for Focal Loss with optional label smoothing.
  • alpha (float, default 0.25): Weight for the positive class (minority class). Values closer to 1.0 give more weight to converters. Typical range: 0.25-0.90.
  • gamma (float, default 2.0): Focusing parameter that controls how strongly easy examples are down-weighted. Higher values focus more on hard examples. Typical range: 2.0-5.0.
  • reduction (str, default 'mean'): Reduction applied to the output. Options: 'mean', 'sum', or 'none'.
  • label_smoothing (float, default 0.0): Label smoothing factor to prevent overconfidence. Typical range: 0.0-0.1.

forward

def forward(
    self,
    inputs: torch.Tensor,
    targets: torch.Tensor
) -> torch.Tensor
Computes the focal loss between predictions and targets.
  • inputs (torch.Tensor): Predicted logits from the model with shape (batch_size, num_classes).
  • targets (torch.Tensor): Ground truth class indices with shape (batch_size,).
Returns:
  • loss (torch.Tensor): Computed focal loss (a scalar when reduction='mean' or 'sum'; per-example losses when reduction='none').

Usage Examples

Basic Usage

from FocalLoss import FocalLoss
import torch

# Initialize focal loss
criterion = FocalLoss(
    alpha=0.90,  # High weight for minority class
    gamma=3.0,   # Strong focus on hard examples
    label_smoothing=0.05
)

# Model predictions (logits)
logits = model(data)

# Compute loss
loss = criterion(logits, labels)
loss.backward()

Training Loop Integration

# Configuration from main.py
criterion = FocalLoss(
    alpha=opt.focal_alpha,      # e.g., 0.90
    gamma=opt.focal_gamma,      # e.g., 3.0
    label_smoothing=opt.label_smoothing  # e.g., 0.05
)

for epoch in range(num_epochs):
    for batch in train_loader:
        optimizer.zero_grad()
        
        # Forward pass
        logits = classifier(graph_seq, None, lengths)
        
        # Compute focal loss
        loss = criterion(logits, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()

Parameter Tuning Guidelines

# For heavily imbalanced datasets (10:1 or more)
criterion_high_imbalance = FocalLoss(
    alpha=0.90,  # Very high weight on minority class
    gamma=3.0,   # Strong focusing
    label_smoothing=0.05
)

# For moderately imbalanced datasets (3:1 to 10:1)
criterion_moderate = FocalLoss(
    alpha=0.75,
    gamma=2.0,
    label_smoothing=0.0
)

# For nearly balanced datasets
criterion_balanced = FocalLoss(
    alpha=0.5,  # Equal weight to both classes
    gamma=2.0,
    label_smoothing=0.0
)

Implementation Details

Loss Computation

The focal loss is computed as:
FL(pt) = -αt * (1 - pt)^γ * log(pt)
Where:
  • pt is the model’s estimated probability for the correct class
  • αt is the class-dependent weighting factor
  • γ is the focusing parameter
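To see how the modulating factor behaves, the formula can be evaluated directly for an easy and a hard example. The helper names below (`focal_term`, `ce_term`) are hypothetical and exist only to illustrate the equation above.

```python
import math


def focal_term(pt: float, alpha: float = 0.25, gamma: float = 2.0) -> float:
    # FL(pt) = -alpha_t * (1 - pt)^gamma * log(pt)
    return -alpha * (1.0 - pt) ** gamma * math.log(pt)


def ce_term(pt: float) -> float:
    # Plain cross-entropy for comparison: -log(pt)
    return -math.log(pt)
```

With gamma = 2, an easy example at pt = 0.9 is scaled by (1 - 0.9)^2 = 0.01, while a hard example at pt = 0.4 is scaled by 0.36, so hard examples dominate the gradient far more than under plain cross-entropy.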

Label Smoothing

When label_smoothing > 0, hard labels are smoothed:
target_smooth = target * (1 - ε) + ε / num_classes
This prevents the model from becoming overconfident.
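The smoothing formula applies to the one-hot encoding of the target. A minimal sketch (the helper name `smooth_labels` is hypothetical, not part of the class's API):

```python
import torch
import torch.nn.functional as F


def smooth_labels(targets: torch.Tensor, num_classes: int,
                  eps: float = 0.05) -> torch.Tensor:
    """Mix one-hot targets with a uniform distribution over classes."""
    one_hot = F.one_hot(targets, num_classes).float()
    # target_smooth = target * (1 - eps) + eps / num_classes
    return one_hot * (1.0 - eps) + eps / num_classes
```

For two classes and eps = 0.05, the true class receives probability 0.975 and the other class 0.025, so each row still sums to 1.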

Configuration in main.py

parser.add_argument('--focal_alpha', type=float, default=0.90,
                   help='focal loss alpha (weight for minority class)')
parser.add_argument('--focal_gamma', type=float, default=3.0,
                   help='focal loss gamma (focusing parameter)')
parser.add_argument('--label_smoothing', type=float, default=0.05,
                   help='label smoothing factor')

Best Practices

  • For AD conversion prediction, start with alpha=0.90 and gamma=3.0 to heavily focus on the minority converter class.
  • Very high gamma values (>5.0) may cause training instability. Monitor validation performance carefully.
  • Label smoothing helps prevent overfitting but may slightly reduce peak performance. Start with 0.0 and add 0.05 if you observe overconfidence.