Overview
The STGNN project uses Focal Loss to address class imbalance in Alzheimer's disease progression prediction. Focal Loss reduces the loss contribution from easy examples and focuses training on hard, misclassified examples. This is particularly important in medical datasets where:
- Cognitively Normal (CN) samples often outnumber Alzheimer's Disease (AD) samples
- Model performance on the minority class (AD) is critical
- Standard cross-entropy can lead to models biased toward the majority class
FocalLoss Class
Class Signature
Defined in FocalLoss.py, lines 5-12.
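The signature itself is not reproduced on this page, so here is a minimal sketch of a class exposing the four documented parameters (alpha, gamma, reduction, label_smoothing). The internals are illustrative, not the repository's exact code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Sketch of the documented interface; internals are illustrative."""

    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean", label_smoothing=0.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        # Per-sample cross-entropy (optionally label-smoothed)
        ce = F.cross_entropy(inputs, targets, reduction="none",
                             label_smoothing=self.label_smoothing)
        pt = torch.exp(-ce)  # model confidence in the correct class
        # alpha for the positive class, 1 - alpha for the negative class
        alpha_t = torch.where(targets == 1,
                              torch.full_like(pt, self.alpha),
                              torch.full_like(pt, 1 - self.alpha))
        loss = alpha_t * (1 - pt) ** self.gamma * ce
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
```

With gamma = 0 and alpha = 0.5 this reduces to half of standard cross-entropy, which is a useful sanity check.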
Parameters
alpha
Weighting factor for the positive class (minority class). Higher values increase focus on the positive class.
Calculation:
- For class 1 (positive): weight = alpha
- For class 0 (negative): weight = 1 - alpha
Typical values:
- 0.25: when the positive class is ~25% of the dataset
- 0.5: balanced datasets
- 0.75: when the positive class is ~75% of the dataset
gamma
Focusing parameter that controls how much to down-weight easy examples.
Effect:
- gamma = 0: equivalent to standard cross-entropy
- gamma = 1: moderate focusing
- gamma = 2: strong focusing (recommended default)
- gamma = 5: very strong focusing
The modulating factor is (1 - pt)^gamma, where pt is the model's confidence in the correct class.
reduction
Specifies the reduction to apply to the output:
- "mean": returns the mean of the losses
- "sum": returns the sum of the losses
- "none": returns the unreduced loss for each sample
label_smoothing
Label smoothing factor to prevent overconfident predictions.
Range: [0.0, 1.0]
- 0.0: no smoothing (hard labels)
- 0.1: mild smoothing (recommended for medical data)
- 0.2+: strong smoothing
Forward Method
Input Parameters
- Predictions: model output logits with shape [B, num_classes]. These are raw logits, not probabilities.
- Targets: ground truth class indices with shape [B]. Values should be integers in range [0, num_classes).
Returns
- reduction='mean': scalar tensor (average loss)
- reduction='sum': scalar tensor (total loss)
- reduction='none': tensor of shape [B] (per-sample loss)
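To make the shapes concrete, here is a small check using F.cross_entropy as a stand-in (its reduction semantics match those listed above):

```python
import torch
import torch.nn.functional as F

inputs = torch.randn(4, 2)            # raw logits, shape [B, num_classes]
targets = torch.tensor([0, 1, 1, 0])  # class indices in [0, num_classes)

# cross_entropy used as a stand-in to show the reduction shapes
mean_loss = F.cross_entropy(inputs, targets, reduction="mean")   # scalar
per_sample = F.cross_entropy(inputs, targets, reduction="none")  # shape [4]
```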
Implementation Details
Standard Focal Loss
Without label smoothing (label_smoothing=0.0), the implementation follows the standard focal loss formulation:
Mathematical Formulation

FL(pt) = -alpha_t * (1 - pt)^gamma * log(pt)

Where:
- pt = probability of the correct class
- alpha_t = class weight (alpha for positive class, 1 - alpha for negative)
- gamma = focusing parameter
With Label Smoothing
When label_smoothing > 0, hard labels are converted to soft labels. With smoothing factor eps and K classes, the true class receives target probability 1 - eps + eps/K and every other class receives eps/K (the standard uniform-smoothing scheme).
Usage Examples
Basic Binary Classification
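The original example code is not reproduced on this page; the following self-contained sketch (a functional stand-in with hypothetical names, matching the formulation above) shows basic binary use:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Functional sketch of binary focal loss with mean reduction."""
    ce = F.cross_entropy(logits, targets, reduction="none")
    pt = torch.exp(-ce)  # probability of the correct class
    alpha_t = torch.where(targets == 1,
                          torch.full_like(pt, alpha),
                          torch.full_like(pt, 1 - alpha))
    return (alpha_t * (1 - pt) ** gamma * ce).mean()

logits = torch.tensor([[2.0, -1.0], [0.3, 0.8]])  # [B, num_classes]
targets = torch.tensor([0, 1])
loss = focal_loss(logits, targets, alpha=0.25, gamma=2.0)
```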
With Label Smoothing
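A sketch of the smoothed variant (assuming label smoothing is applied inside the cross-entropy term, as PyTorch's F.cross_entropy supports; pt is derived from the smoothed loss here, which is an approximation):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0, label_smoothing=0.0):
    # Label-smoothed per-sample CE; pt derived from the smoothed loss (sketch)
    ce = F.cross_entropy(logits, targets, reduction="none",
                         label_smoothing=label_smoothing)
    pt = torch.exp(-ce)
    alpha_t = torch.where(targets == 1,
                          torch.full_like(pt, alpha),
                          torch.full_like(pt, 1 - alpha))
    return (alpha_t * (1 - pt) ** gamma * ce).mean()

# An extremely confident (and correct) prediction still incurs some loss
# under smoothing, which discourages overconfidence.
logits = torch.tensor([[10.0, -10.0]])
targets = torch.tensor([0])
smoothed = focal_loss(logits, targets, label_smoothing=0.1)
hard = focal_loss(logits, targets, label_smoothing=0.0)
```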
Per-Sample Loss
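With reduction='none', the loss is returned per sample, which is useful for inspecting or re-weighting the hardest examples. A self-contained sketch:

```python
import torch
import torch.nn.functional as F

def focal_loss_per_sample(logits, targets, alpha=0.25, gamma=2.0):
    """Sketch of reduction='none': one loss value per sample."""
    ce = F.cross_entropy(logits, targets, reduction="none")
    pt = torch.exp(-ce)
    alpha_t = torch.where(targets == 1,
                          torch.full_like(pt, alpha),
                          torch.full_like(pt, 1 - alpha))
    return alpha_t * (1 - pt) ** gamma * ce  # shape [B]

torch.manual_seed(0)
logits = torch.randn(5, 2)
targets = torch.randint(0, 2, (5,))
per_sample = focal_loss_per_sample(logits, targets)
hardest = int(per_sample.argmax())  # index of the hardest sample in the batch
```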
Training Loop Integration
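A minimal training-loop sketch. The model and random data below are hypothetical stand-ins; the real project would use its STGNN model and DataLoader:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    ce = F.cross_entropy(logits, targets, reduction="none")
    pt = torch.exp(-ce)
    alpha_t = torch.where(targets == 1,
                          torch.full_like(pt, alpha),
                          torch.full_like(pt, 1 - alpha))
    return (alpha_t * (1 - pt) ** gamma * ce).mean()

model = nn.Linear(16, 2)  # stand-in for the project's model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(3):  # a few illustrative steps
    x = torch.randn(8, 16)
    y = torch.randint(0, 2, (8,))
    optimizer.zero_grad()
    loss = focal_loss(model(x), y)
    loss.backward()
    optimizer.step()
```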
Hyperparameter Tuning
Alpha Selection
Set alpha based on class distribution.
Gamma Selection
Adjust gamma based on model performance.
Grid Search
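The alpha and gamma guidance above can be sketched as a small grid search. The loss-only evaluation below is a placeholder; a real run would evaluate validation metrics (e.g. minority-class F1) for each setting:

```python
import itertools
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha, gamma):
    ce = F.cross_entropy(logits, targets, reduction="none")
    pt = torch.exp(-ce)
    alpha_t = torch.where(targets == 1,
                          torch.full_like(pt, alpha),
                          torch.full_like(pt, 1 - alpha))
    return (alpha_t * (1 - pt) ** gamma * ce).mean()

torch.manual_seed(0)
logits = torch.randn(32, 2)               # hypothetical validation logits
targets = (torch.rand(32) < 0.25).long()  # ~25% positive class

# Per the alpha table above: start from the positive-class fraction
alpha_from_data = float((targets == 1).float().mean())

results = {}
for alpha, gamma in itertools.product([0.25, 0.5, 0.75], [1.0, 2.0, 5.0]):
    results[(alpha, gamma)] = float(focal_loss(logits, targets, alpha, gamma))
```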
Comparison with Cross-Entropy
Standard Cross-Entropy
Weighted Cross-Entropy
Focal Loss
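The three variants can be contrasted directly. A sketch using PyTorch functional ops (alpha omitted from the focal term for clarity):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[3.0, -3.0], [0.2, 0.1]])  # one easy, one hard example
targets = torch.tensor([0, 0])

# Standard CE: no class balancing, no hard-example focusing
ce = F.cross_entropy(logits, targets)

# Weighted CE: balances classes, but still no focusing
wce = F.cross_entropy(logits, targets, weight=torch.tensor([0.75, 0.25]))

# Focal: down-weights the easy example via (1 - pt)^gamma
per = F.cross_entropy(logits, targets, reduction="none")
pt = torch.exp(-per)
focal = ((1 - pt) ** 2.0 * per).mean()
```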
Performance Comparison
| Loss Function | Handles Class Imbalance | Hard-Example Focus | Overconfidence |
|---|---|---|---|
| Cross-Entropy | ❌ | ❌ | High |
| Weighted CE | ✅ | ❌ | High |
| Focal Loss | ✅ | ✅ | Low |
| Focal + Smoothing | ✅ | ✅ | Very Low |
Behavior Analysis
Easy vs Hard Examples
- When pt is high (easy example), focal loss is much lower than CE
- When pt is low (hard example), focal loss is similar to CE
- The effect amplifies with higher gamma values
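A quick numeric check of these claims: the fraction of CE retained by the modulating factor at gamma = 2, for an easy versus a hard example:

```python
import torch
import torch.nn.functional as F

gamma = 2.0
target = torch.tensor([0])
examples = {"easy": torch.tensor([[4.0, -4.0]]),   # confident and correct
            "hard": torch.tensor([[0.1, -0.1]])}   # near the decision boundary

retained = {}
for name, logits in examples.items():
    ce = F.cross_entropy(logits, target)
    pt = torch.exp(-ce)
    retained[name] = float((1 - pt) ** gamma)  # fraction of CE kept by focal loss
```

The easy example keeps only a tiny fraction of its CE, while the hard example keeps a much larger share.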
Gamma Effect
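The gamma effect reduces to how fast the modulating factor (1 - pt)^gamma shrinks for a confident prediction, e.g. pt = 0.9:

```python
import torch

pt = torch.tensor(0.9)  # a fairly easy example
factors = {g: float((1 - pt) ** g) for g in [0.0, 1.0, 2.0, 5.0]}
# gamma = 0 keeps the full CE contribution; larger gamma suppresses it rapidly
```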
Best Practices
For Alzheimer’s Disease Classification
For Severely Imbalanced Data
For Multi-stage Training
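The original configuration snippets are not shown on this page. The settings below are one plausible reading of the guidance above (alpha near the positive-class fraction, gamma = 2 default, mild smoothing for medical data, gentler focusing early in multi-stage training); they are illustrative, not taken from the repository:

```python
# Illustrative configurations derived from the guidance on this page:
configs = {
    # CN vs AD with ~25% positive class: default focusing, mild smoothing
    "alzheimers_cn_vs_ad": dict(alpha=0.25, gamma=2.0,
                                reduction="mean", label_smoothing=0.1),
    # rarer positive class: smaller alpha per the table, stronger focusing
    "severe_imbalance": dict(alpha=0.1, gamma=3.0,
                             reduction="mean", label_smoothing=0.1),
    # early stage of multi-stage training: gentler focusing, hard labels
    "stage1_warmup": dict(alpha=0.5, gamma=1.0,
                          reduction="mean", label_smoothing=0.0),
}
```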
Debugging Tips
Check Loss Values
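A sketch of a defensive wrapper (hypothetical helper name) that flags non-finite values and checks the invariant that focal loss never exceeds the alpha-weighted CE:

```python
import torch
import torch.nn.functional as F

def debug_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    ce = F.cross_entropy(logits, targets, reduction="none")
    pt = torch.exp(-ce)
    alpha_t = torch.where(targets == 1,
                          torch.full_like(pt, alpha),
                          torch.full_like(pt, 1 - alpha))
    loss = alpha_t * (1 - pt) ** gamma * ce
    assert torch.isfinite(loss).all(), "non-finite focal loss: check logits/targets"
    # Since (1 - pt)^gamma <= 1, focal loss is bounded by alpha-weighted CE
    assert bool((loss <= alpha_t * ce + 1e-6).all())
    return loss.mean()

torch.manual_seed(0)
logits = torch.randn(16, 2)
targets = torch.randint(0, 2, (16,))
val = debug_focal_loss(logits, targets)
```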
Visualize Predictions
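A minimal way to inspect prediction confidence (random logits as a stand-in for model outputs):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(32, 2)               # stand-in for model outputs
probs = F.softmax(logits, dim=1)          # convert logits to probabilities
confidence, predicted = probs.max(dim=1)  # per-sample confidence and class

# Many near-1.0 confidences on misclassified samples suggest overconfidence;
# consider label_smoothing > 0 in that case.
mean_conf = float(confidence.mean())
```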
Compare with Baseline
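As a sanity check against plain cross-entropy: because (1 - pt)^gamma <= 1, the unweighted focal loss never exceeds CE on the same batch, and a large gap indicates most of the batch is already "easy" for the model:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(64, 2)
targets = (torch.rand(64) < 0.2).long()  # hypothetical ~20% positive class

ce = F.cross_entropy(logits, targets)
per = F.cross_entropy(logits, targets, reduction="none")
pt = torch.exp(-per)
focal = ((1 - pt) ** 2.0 * per).mean()  # unweighted focal term

gap = float(ce - focal)  # large gap -> batch dominated by easy examples
```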
See Also
- Temporal Predictor - Model architectures that use FocalLoss
- Training Overview - Integration into training workflow
- Evaluation Metrics - Assessing model performance on imbalanced data