Overview

The STGNN project uses Focal Loss to address class imbalance in Alzheimer’s disease progression prediction. Focal Loss reduces the loss contribution from easy examples and focuses training on hard, misclassified examples. This is particularly important in medical datasets where:
  • Cognitively Normal (CN) samples often outnumber Alzheimer’s Disease (AD) samples
  • Model performance on the minority class (AD) is critical
  • Standard cross-entropy can lead to models biased toward the majority class

FocalLoss Class

Class Signature

class FocalLoss(nn.Module):
    def __init__(self, alpha: float = 0.25, gamma: float = 2.0, 
                 reduction: str = 'mean', label_smoothing: float = 0.0)
From FocalLoss.py lines 5-12.

Parameters

alpha
float
default:0.25
Weighting factor for the positive class (minority class). Higher values increase the weight on the positive class. Calculation:
  • For class 1 (positive): weight = alpha
  • For class 0 (negative): weight = 1 - alpha
Typical values:
  • 0.25: When the positive class is 25% of the dataset
  • 0.5: Balanced datasets
  • 0.75: When the positive class is 75% of the dataset
gamma
float
default:2.0
Focusing parameter that controls how strongly easy examples are down-weighted. Effect:
  • gamma = 0: Equivalent to standard cross-entropy
  • gamma = 1: Moderate focusing
  • gamma = 2: Strong focusing (recommended default)
  • gamma = 5: Very strong focusing
The loss is multiplied by (1 - pt)^gamma, where pt is the model's predicted probability for the correct class.
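A quick numeric check of the modulating factor (1 - pt)^gamma shows how strongly confident (easy) predictions are suppressed relative to low-confidence (hard) ones:

```python
# Modulating factor (1 - pt)^gamma for an easy example (pt = 0.9)
# versus a hard example (pt = 0.3), across several gamma values.
for gamma in (0, 1, 2, 5):
    easy = (1 - 0.9) ** gamma
    hard = (1 - 0.3) ** gamma
    print(f"gamma={gamma}: easy factor={easy:.4f}, hard factor={hard:.4f}")
```

With gamma = 2 the easy example's loss is scaled by 0.01 while the hard example keeps 0.49, a roughly 49x relative emphasis on the hard example; with gamma = 0 both factors are 1 and the loss is plain cross-entropy.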
reduction
str
default:"mean"
Specifies the reduction to apply to the output:
  • "mean": Returns the mean of the losses
  • "sum": Returns the sum of the losses
  • "none": Returns unreduced losses for each sample
label_smoothing
float
default:0.0
Label smoothing factor to prevent overconfident predictions. Range: [0.0, 1.0]
  • 0.0: No smoothing (hard labels)
  • 0.1: Mild smoothing (recommended for medical data)
  • 0.2+: Strong smoothing
Transforms hard labels into soft labels:
smoothed_label = label * (1 - smoothing) + smoothing / num_classes
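Applying that transformation by hand (pure Python, two classes) reproduces the soft labels shown later on this page:

```python
def smooth_labels(one_hot, smoothing, num_classes=2):
    # smoothed = label * (1 - smoothing) + smoothing / num_classes
    return [v * (1 - smoothing) + smoothing / num_classes for v in one_hot]

# Hard label for class 1 with 10% smoothing -> approximately [0.05, 0.95]
print(smooth_labels([0.0, 1.0], 0.1))
```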

Forward Method

def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Args:
        inputs: Model logits [B, num_classes]
        targets: Ground truth class indices [B]
    
    Returns:
        Focal loss value (scalar if reduction='mean' or 'sum', [B] if reduction='none')
    """

Input Parameters

inputs
torch.Tensor
Model output logits with shape [B, num_classes]. These are raw logits, not probabilities. Example:
inputs = torch.tensor([[2.5, -1.0], [0.5, 0.5]])  # [B=2, num_classes=2]
targets
torch.Tensor
Ground truth class indices with shape [B]. Values should be integers in the range [0, num_classes). Example:
targets = torch.tensor([1, 0])  # [B=2]

Returns

  • reduction='mean': Scalar tensor (average loss)
  • reduction='sum': Scalar tensor (total loss)
  • reduction='none': Tensor of shape [B] (per-sample loss)

Implementation Details

Standard Focal Loss

Without label smoothing (label_smoothing=0.0), the implementation follows the standard focal loss formulation:
# From FocalLoss.py lines 23-32
ce_loss = F.cross_entropy(inputs, targets, reduction='none')

# Get probabilities for the correct class
pt = torch.exp(-ce_loss)

# Calculate alpha weights
alpha_t = torch.where(targets == 1, self.alpha, 1 - self.alpha)

# Compute focal loss
focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss
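The same computation can be sketched as a standalone, pure-Python function for a single binary sample (an illustrative helper, not part of FocalLoss.py). It mirrors the torch snippet above, and with gamma = 0 and alpha_t = 1 it reduces to plain cross-entropy:

```python
import math

def focal_loss_single(pt, target, alpha=0.25, gamma=2.0):
    """Focal loss for one sample, given pt = predicted prob of the true class."""
    alpha_t = alpha if target == 1 else 1 - alpha   # mirrors torch.where above
    ce = -math.log(pt)                              # per-sample cross-entropy
    return alpha_t * (1 - pt) ** gamma * ce

# A confident correct prediction contributes far less than an uncertain one.
print(focal_loss_single(0.95, target=1))  # easy example, tiny loss
print(focal_loss_single(0.30, target=1))  # hard example, much larger loss
```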

Mathematical Formulation

FL(p_t) = -α_t (1 - p_t)^γ log(p_t)

Where:
  • p_t = predicted probability of the correct class
  • α_t = class weight (alpha for the positive class, 1 - alpha for the negative class)
  • γ = focusing parameter

With Label Smoothing

When label_smoothing > 0, hard labels are converted to soft labels:
# From FocalLoss.py lines 17-21
if self.label_smoothing > 0:
    targets_one_hot = F.one_hot(targets, num_classes=inputs.size(1)).float()
    targets_one_hot = targets_one_hot * (1 - self.label_smoothing) + \
                     self.label_smoothing / inputs.size(1)
    ce_loss = -torch.sum(targets_one_hot * F.log_softmax(inputs, dim=1), dim=1)
Example:
# Original: [0, 1] (class 1)
# With smoothing=0.1: [0.05, 0.95]
# With smoothing=0.2: [0.1, 0.9]

Usage Examples

Basic Binary Classification

import torch
from FocalLoss import FocalLoss

# Initialize focal loss
criterion = FocalLoss(alpha=0.25, gamma=2.0)

# Model outputs and targets
logits = torch.tensor([[2.5, -1.0], [0.5, 0.5], [-1.0, 2.0]])  # [3, 2]
targets = torch.tensor([0, 1, 1])  # [3]

# Compute loss
loss = criterion(logits, targets)
print(f"Loss: {loss.item():.4f}")

With Label Smoothing

# Prevent overconfident predictions
criterion = FocalLoss(
    alpha=0.25,
    gamma=2.0,
    label_smoothing=0.1  # 10% smoothing
)

loss = criterion(logits, targets)
print(f"Smoothed loss: {loss.item():.4f}")

Per-Sample Loss

# Get unreduced loss for analysis
criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='none')

losses = criterion(logits, targets)  # [3]
print(f"Per-sample losses: {losses}")

# Identify hard examples
hard_examples = losses > losses.mean()
print(f"Hard example indices: {hard_examples.nonzero().squeeze()}")

Training Loop Integration

import torch.optim as optim

# Setup
model = STGNNModel()
criterion = FocalLoss(alpha=0.25, gamma=2.0)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        # Forward pass
        logits = model(batch)
        loss = criterion(logits, batch.y)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Hyperparameter Tuning

Alpha Selection

Set alpha based on class distribution:
# Calculate class weights from dataset
num_positive = (targets == 1).sum().item()
num_negative = (targets == 0).sum().item()
total = num_positive + num_negative

# Alpha = proportion of positive class
alpha = num_positive / total
print(f"Recommended alpha: {alpha:.3f}")

criterion = FocalLoss(alpha=alpha, gamma=2.0)
Example scenarios:
# Balanced dataset (500 CN, 500 AD)
alpha = 0.5

# Imbalanced dataset (750 CN, 250 AD)
alpha = 0.25  # Matches the AD proportion; (1 - pt)^gamma supplies the focus on hard AD samples

# Severely imbalanced (900 CN, 100 AD)
alpha = 0.1   # Matches the AD proportion

Gamma Selection

Adjust gamma based on model performance:
# Start with standard value
criterion = FocalLoss(alpha=0.25, gamma=2.0)

# If model focuses too much on easy examples:
criterion = FocalLoss(alpha=0.25, gamma=3.0)  # Stronger focusing

# If model struggles with easy examples:
criterion = FocalLoss(alpha=0.25, gamma=1.0)  # Milder focusing

# Extreme imbalance:
criterion = FocalLoss(alpha=0.25, gamma=5.0)  # Very strong focusing

Grid Search

To tune the three hyperparameters jointly, run a simple grid search against a validation set:
from itertools import product

alphas = [0.1, 0.25, 0.5, 0.75]
gammas = [0.5, 1.0, 2.0, 3.0]
label_smoothings = [0.0, 0.05, 0.1]

best_loss = float('inf')
best_params = None

for alpha, gamma, smoothing in product(alphas, gammas, label_smoothings):
    criterion = FocalLoss(alpha=alpha, gamma=gamma, label_smoothing=smoothing)
    
    # Evaluate on validation set
    val_loss = evaluate(model, val_loader, criterion)
    
    if val_loss < best_loss:
        best_loss = val_loss
        best_params = (alpha, gamma, smoothing)

print(f"Best params: alpha={best_params[0]}, gamma={best_params[1]}, "
      f"smoothing={best_params[2]}")

Comparison with Cross-Entropy

Standard Cross-Entropy

import torch.nn as nn

# Standard loss (no class balancing)
ce_criterion = nn.CrossEntropyLoss()
loss_ce = ce_criterion(logits, targets)

Weighted Cross-Entropy

# Class weights based on inverse frequency
class_weights = torch.tensor([1.0, 3.0])  # 3x weight for minority class
wce_criterion = nn.CrossEntropyLoss(weight=class_weights)
loss_wce = wce_criterion(logits, targets)

Focal Loss

# Combines class weighting + hard example mining
focal_criterion = FocalLoss(alpha=0.25, gamma=2.0)
loss_focal = focal_criterion(logits, targets)

Performance Comparison

Loss Function       Class Balance   Hard Example Focus   Overconfidence
Cross-Entropy       No              No                   High
Weighted CE         Yes             No                   High
Focal Loss          Yes             Yes                  Low
Focal + Smoothing   Yes             Yes                  Very Low

Behavior Analysis

Easy vs Hard Examples

import matplotlib.pyplot as plt
import numpy as np

def plot_focal_loss_curve(alpha=0.25, gamma=2.0):
    pt = np.linspace(0.01, 0.99, 100)  # Probability of correct class
    
    # Standard cross-entropy
    ce = -np.log(pt)
    
    # Focal loss
    focal = -alpha * (1 - pt)**gamma * np.log(pt)
    
    plt.figure(figsize=(10, 6))
    plt.plot(pt, ce, label='Cross-Entropy', linewidth=2)
    plt.plot(pt, focal, label=f'Focal (γ={gamma})', linewidth=2)
    plt.xlabel('Probability of Correct Class (pt)')
    plt.ylabel('Loss')
    plt.title('Focal Loss vs Cross-Entropy')
    plt.legend()
    plt.grid(True)
    plt.show()

plot_focal_loss_curve()
Observations:
  • When pt is high (easy example), focal loss is much lower than CE
  • When pt is low (hard example), focal loss is similar to CE
  • Effect amplifies with higher gamma values
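Those observations can also be verified numerically without plotting (pure Python, same alpha = 0.25 and gamma = 2.0 as above):

```python
import math

alpha, gamma = 0.25, 2.0
for pt in (0.1, 0.5, 0.9, 0.99):
    ce = -math.log(pt)
    focal = alpha * (1 - pt) ** gamma * ce
    # focal/CE = alpha * (1 - pt)^gamma: the fraction of CE loss that survives
    print(f"pt={pt:.2f}  CE={ce:.4f}  focal={focal:.4f}  focal/CE={focal / ce:.4f}")
```

At pt = 0.9 the focal loss keeps only 0.25 × 0.01 = 0.25% of the cross-entropy value, while at pt = 0.1 it keeps about 20% — easy examples are suppressed far more aggressively than hard ones.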

Gamma Effect

def compare_gamma_values(alpha=0.25):
    pt = np.linspace(0.01, 0.99, 100)
    
    plt.figure(figsize=(10, 6))
    for gamma in [0, 1, 2, 5]:
        focal = -alpha * (1 - pt)**gamma * np.log(pt)
        plt.plot(pt, focal, label=f'γ={gamma}', linewidth=2)
    
    plt.xlabel('Probability of Correct Class (pt)')
    plt.ylabel('Loss')
    plt.title('Effect of Gamma Parameter')
    plt.legend()
    plt.grid(True)
    plt.show()

compare_gamma_values()

Best Practices

For Alzheimer’s Disease Classification

# Typical ADNI dataset: ~60% CN, ~40% AD
criterion = FocalLoss(
    alpha=0.4,              # Slight focus on AD (minority)
    gamma=2.0,              # Standard focusing
    reduction='mean',
    label_smoothing=0.1     # Prevent overconfidence in medical context
)

For Severely Imbalanced Data

# Very imbalanced: 90% CN, 10% AD
criterion = FocalLoss(
    alpha=0.1,              # Strong focus on AD
    gamma=3.0,              # Stronger focusing on hard examples
    reduction='mean',
    label_smoothing=0.05    # Mild smoothing
)

For Multi-stage Training

# Stage 1: Warm-up with milder focusing
criterion_warmup = FocalLoss(alpha=0.25, gamma=1.0)

# Stage 2: Standard focusing
criterion_main = FocalLoss(alpha=0.25, gamma=2.0)

# Stage 3: Fine-tuning with strong focusing
criterion_finetune = FocalLoss(alpha=0.25, gamma=3.0, label_smoothing=0.1)
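One way to wire the three stages into a training loop is a small epoch-to-parameters schedule. The epoch thresholds below are illustrative assumptions, not values from the project:

```python
def stage_params(epoch, warmup_end=10, main_end=40):
    """Return (gamma, label_smoothing) for the current training stage.

    warmup_end and main_end are illustrative thresholds, not project values.
    """
    if epoch < warmup_end:
        return 1.0, 0.0   # Stage 1: warm-up, milder focusing
    if epoch < main_end:
        return 2.0, 0.0   # Stage 2: standard focusing
    return 3.0, 0.1       # Stage 3: fine-tuning, strong focusing + smoothing

print(stage_params(5))    # warm-up stage
print(stage_params(50))   # fine-tuning stage
```

Rebuild the criterion whenever the returned parameters change, e.g. `criterion = FocalLoss(alpha=0.25, gamma=g, label_smoothing=s)`.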

Debugging Tips

Check Loss Values

# Monitor loss magnitudes
criterion = FocalLoss(alpha=0.25, gamma=2.0)

logits = model(batch)
loss = criterion(logits, targets)

if loss.item() > 10.0:
    print("Warning: Very high loss, check model outputs")
elif loss.item() < 0.01:
    print("Warning: Very low loss, possible overfitting")

Visualize Predictions

# Analyze model confidence
probs = torch.softmax(logits, dim=1)
confidence, predictions = probs.max(dim=1)

print(f"Mean confidence: {confidence.mean().item():.3f}")
print(f"Min confidence: {confidence.min().item():.3f}")
print(f"Max confidence: {confidence.max().item():.3f}")

# Check for overconfident predictions
overconfident = confidence > 0.95
print(f"Overconfident predictions: {overconfident.sum().item()} / {len(confidence)}")

Compare with Baseline

# Train two models for comparison
model_ce = train(model, nn.CrossEntropyLoss())
model_focal = train(model, FocalLoss(alpha=0.25, gamma=2.0))

# Evaluate on imbalanced test set
results_ce = evaluate(model_ce, test_loader)
results_focal = evaluate(model_focal, test_loader)

print(f"CE - Accuracy: {results_ce['acc']:.3f}, F1: {results_ce['f1']:.3f}")
print(f"Focal - Accuracy: {results_focal['acc']:.3f}, F1: {results_focal['f1']:.3f}")
