The Imbalance Problem
Alzheimer’s disease progression datasets are inherently imbalanced:
- Stable subjects (class 0): majority class (typically 70-85%)
- Converter subjects (class 1): minority class (typically 15-30%)
Without countermeasures, a naive model can:
- Predict “stable” for all subjects
- Achieve high accuracy but zero minority recall
- Fail to identify patients at risk of conversion
Class Distribution
Check your dataset distribution (main.py:311-313):
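The snippet at main.py:311-313 is not reproduced here; as a minimal stand-in, the distribution can be checked like this (the `class_distribution` helper is hypothetical):

```python
from collections import Counter

def class_distribution(labels):
    """Fraction of each class in a list of 0/1 labels (illustrative helper)."""
    counts = Counter(labels)
    total = len(labels)
    return {cls: counts.get(cls, 0) / total for cls in (0, 1)}

labels = [0] * 80 + [1] * 20  # e.g., 80% stable, 20% converters
print(class_distribution(labels))  # {0: 0.8, 1: 0.2}
```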
Focal Loss
Overview
Focal Loss addresses class imbalance by:
- Down-weighting easy examples (high-confidence correct predictions)
- Up-weighting hard examples (low-confidence or misclassified)
- Class-based weighting (giving more importance to minority class)
Mathematical Formulation
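The description above matches the standard class-weighted focal loss (Lin et al., 2017); assuming FocalLoss.py follows it, the loss for a single example is:

```latex
\mathrm{FL}(p_t) = -\,\alpha_t \,(1 - p_t)^{\gamma}\, \log(p_t)
```

where p_t is the predicted probability of the true class, α_t is the class weight (α for converters, 1 − α for stable), and γ controls how strongly easy examples are down-weighted.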
Implemented in FocalLoss.py:14-38.
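As a pure-Python sketch of that formula for a single example (the real FocalLoss.py:14-38 presumably operates on torch tensors in batch; names here are illustrative):

```python
import math

def focal_loss(p, target, alpha=0.90, gamma=3.0):
    """Focal loss for one binary example.
    p: predicted probability of class 1; target: 0 or 1."""
    p_t = p if target == 1 else 1.0 - p             # probability of the true class
    alpha_t = alpha if target == 1 else 1.0 - alpha  # class weight
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# An easy, confident example contributes far less than a hard one:
easy = focal_loss(0.95, 1)   # high-confidence correct converter
hard = focal_loss(0.30, 1)   # low-confidence converter
```

With gamma=0 this reduces to plain alpha-weighted cross-entropy.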
Alpha Parameter
Controls class weighting (--focal_alpha):
- Default (0.90)
- Moderate (0.75)
- Balanced (0.50)
With the default α = 0.90, the effective per-class weights are (FocalLoss.py:29):
- Converters (class 1): 0.90
- Stable (class 0): 0.10
Gamma Parameter
Controls focusing strength (--focal_gamma):
- No Focusing (0)
- Mild Focusing (1)
- Standard Focusing (2)
- Strong Focusing (3+)
The focusing factor (1 − p_t)^γ is applied in FocalLoss.py:26-32.
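The effect of γ can be seen directly in the modulating factor (1 − p_t)^γ (a numerical illustration, not code from the repo):

```python
def modulating_factor(p_t, gamma):
    """Scales the loss of an example whose true-class probability is p_t."""
    return (1.0 - p_t) ** gamma

# gamma = 2: an easy example (p_t = 0.9) is down-weighted 100x,
# while a hard example (p_t = 0.3) keeps about half its loss.
easy = modulating_factor(0.9, 2)   # ~0.01
hard = modulating_factor(0.3, 2)   # ~0.49
```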
Label Smoothing
Prevents overconfident predictions (--label_smoothing):
- Original: [1, 0] (hard label)
- Smoothed (0.05): [0.95, 0.05] (soft label)
- Smoothed (0.10): [0.90, 0.10] (softer label)
Benefits:
- Reduces overfitting
- Improves calibration
- Helps with noisy labels
Implemented in FocalLoss.py:17-21.
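Following the convention shown above (1 → 1 − s, 0 → s), smoothing can be sketched as (illustrative, not the exact FocalLoss.py code):

```python
def smooth_labels(hard, smoothing=0.05):
    """Soften a one-hot target, e.g. [1, 0] -> [0.95, 0.05]."""
    return [v * (1.0 - smoothing) + (1.0 - v) * smoothing for v in hard]

soft_05 = smooth_labels([1, 0], smoothing=0.05)   # ~[0.95, 0.05]
soft_10 = smooth_labels([1, 0], smoothing=0.10)   # ~[0.90, 0.10]
```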
Minority Class Forcing
Overview
An additional loss term applied during early training to encourage minority class predictions (main.py:465-482).
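A hedged sketch of such a term (main.py:465-482 is not shown; one common choice is the mean negative log-probability assigned to the minority class):

```python
import math

def minority_forcing_loss(probs_class1):
    """Penalize low predicted probability for the converter class
    across a batch (illustrative; the repo's exact term may differ)."""
    eps = 1e-8  # avoid log(0)
    return -sum(math.log(max(p, eps)) for p in probs_class1) / len(probs_class1)
```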
Forcing Schedule
Forcing weight decays over epochs:

| Epoch | Forcing Weight | Effect |
|---|---|---|
| 1 | 0.100 | Strong forcing |
| 5 | 0.075 | Moderate forcing |
| 10 | 0.050 | Mild forcing |
| 15 | 0.025 | Weak forcing |
| 20 | 0.000 | No forcing |
| 21+ | 0.000 | No forcing |
Controlled by --minority_focus_epochs (default: 20).
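The table is consistent with a linear decay of a base weight 0.1 down to zero over --minority_focus_epochs; a sketch (the exact epoch indexing in main.py may differ by one):

```python
def forcing_weight(epoch, minority_focus_epochs=20, base=0.1):
    """Linearly decay the forcing weight to 0 by minority_focus_epochs."""
    return base * max(0.0, 1.0 - epoch / minority_focus_epochs)

w = forcing_weight(10)   # 0.05, matching the table above
```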
Total Loss
Combination of Focal Loss and forcing loss (main.py:524-529).
The forcing loss is applied IN ADDITION to Focal Loss, not as a replacement. This provides extra emphasis on minority class during the critical early training phase.
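Putting the pieces together, an illustrative sketch of the combination in main.py:524-529, assuming a linear forcing schedule (names are assumptions):

```python
def total_loss(focal, forcing, epoch, minority_focus_epochs=20, base=0.1):
    """Focal loss plus the decaying minority-forcing term."""
    w = base * max(0.0, 1.0 - epoch / minority_focus_epochs)
    return focal + w * forcing
```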
Monitoring Class Predictions
Training Monitoring
Track prediction distribution during training (main.py:510-549):
- All predictions are class 0: Model collapsed to majority class
- All predictions are class 1: Forcing is too strong
- Balanced predictions: Good, but check if they’re accurate
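A minimal way to track this (illustrative; main.py:510-549 is not shown):

```python
from collections import Counter

def prediction_report(preds):
    """Count predicted classes; unique_preds == 1 signals collapse."""
    counts = Counter(preds)
    return len(counts), dict(counts)

unique_preds, counts = prediction_report([0, 0, 0, 1, 0])
```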
Validation Monitoring
Evaluate with detailed metrics (main.py:319-380; see also main.py:565-567).
Evaluation Metrics
Balanced Accuracy
Average of per-class recalls (main.py:358-359):
- Standard accuracy can be misleading with imbalance
- On a typical 80/20 split, a model predicting all “stable” gets 80% accuracy but only 50% balanced accuracy
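The 80%/50% example can be reproduced with a small pure-Python calculation (helper names are illustrative):

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recalls for binary labels."""
    recalls = []
    for cls in (0, 1):
        idx = [i for i, t in enumerate(y_true) if t == cls]
        recalls.append(sum(1 for i in idx if y_pred[i] == cls) / len(idx))
    return sum(recalls) / 2

y_true = [0] * 80 + [1] * 20
y_pred = [0] * 100                     # always predicts "stable"
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / 100
print(accuracy, balanced_accuracy(y_true, y_pred))  # 0.8 0.5
```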
Minority F1 Score
F1 score for converter class (main.py:356):
- F1 = 0.0: Model never predicts converters
- F1 = 0.5: Reasonable performance
- F1 = 0.7+: Good minority class detection
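Minority F1 can be computed from true/false positives on the converter class alone (a sketch, not main.py:356 itself):

```python
def minority_f1(y_true, y_pred, minority=1):
    """F1 score for the minority (converter) class."""
    tp = sum(t == minority and p == minority for t, p in zip(y_true, y_pred))
    fp = sum(t != minority and p == minority for t, p in zip(y_true, y_pred))
    fn = sum(t == minority and p != minority for t, p in zip(y_true, y_pred))
    if tp == 0:
        return 0.0                     # never predicts converters correctly
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)
```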
AUC-ROC
Area under ROC curve (main.py:362):
- Threshold-independent metric
- Considers full probability distribution
- Good for imbalanced datasets
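AUC-ROC equals the probability that a randomly chosen converter receives a higher score than a randomly chosen stable subject; a small rank-based sketch (illustrative, not main.py:362):

```python
def auc_roc(y_true, scores):
    """Probability a random positive outranks a random negative (ties = 0.5)."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))
```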
Recommended Configurations
Severe Imbalance (< 20% minority)
Moderate Imbalance (20-35% minority)
Mild Imbalance (35-45% minority)
No Forcing (Focal Loss Only)
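The exact presets for these four scenarios are not shown here; as illustrative starting points using the flags documented in this guide (values should be tuned to your data):

```shell
# Severe imbalance (< 20% minority) — matches the tuned defaults noted at the end
python main.py --focal_alpha 0.90 --focal_gamma 3.0 \
               --label_smoothing 0.05 --minority_focus_epochs 20

# Moderate imbalance (20-35% minority)
python main.py --focal_alpha 0.75 --focal_gamma 2.0 \
               --label_smoothing 0.05 --minority_focus_epochs 10

# Mild imbalance (35-45% minority)
python main.py --focal_alpha 0.50 --focal_gamma 1.0 --minority_focus_epochs 5

# No forcing (Focal Loss only)
python main.py --focal_alpha 0.75 --focal_gamma 2.0 --minority_focus_epochs 0
```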
Gradient Clipping
Prevents exploding gradients during minority forcing (main.py:533-536).
Gradient clipping is particularly important when using minority forcing, as the additional loss term can cause large gradient spikes early in training.
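Conceptually, clipping rescales the whole gradient vector when its global L2 norm exceeds a threshold (what torch.nn.utils.clip_grad_norm_ does; this pure-Python sketch is illustrative, not main.py:533-536):

```python
import math

def clip_grad_norm(grads, max_norm=1.0):
    """Rescale gradients so their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        return [g * scale for g in grads]
    return list(grads)
```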
Troubleshooting
Model Predicts Only Majority Class
Symptoms:
- All predictions are class 0
- High accuracy but zero minority recall
- unique_preds = 1
Solutions:
- Increase focal_alpha (e.g., 0.95)
- Increase focal_gamma (e.g., 4.0)
- Extend minority_focus_epochs (e.g., 30)
- Reduce learning rate (--lr 0.0005)
- Check if minority class has enough samples (< 10% is very hard)
Model Predicts Only Minority Class
Symptoms:
- All predictions are class 1
- Low accuracy
Cause: minority class forcing is too aggressive.
Solutions:
- Decrease focal_alpha (e.g., 0.75)
- Decrease focal_gamma (e.g., 1.0)
- Reduce minority_focus_epochs (e.g., 10)
- Lower the forcing weight in code (change 0.1 to 0.05)
Unstable Training
Symptoms:
- Predictions oscillate between all-0 and all-1
- Loss spikes
- NaN gradients
Solutions:
- Add/increase label smoothing (--label_smoothing 0.1)
- Reduce learning rate (--lr 0.0005)
- Check that gradient clipping is enabled
- Reduce batch size for more stable gradients
- Lower focal_gamma (extreme values like 5+ can destabilize)
Low Minority F1 Despite Predictions
Symptoms:
- Model predicts some converters
- But minority F1 < 0.3
- Predictions are often wrong
Possible causes and solutions:
- Model may be underfitting:
  - Increase model capacity (--gnn_hidden_dim 512 --lstm_hidden_dim 128)
  - Train longer (--n_epochs 150)
- Data quality issues:
  - Check for label noise
  - Verify converter definitions
- Reduce label smoothing (may be over-regularizing)
Class Weights Alternative
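Weighted cross-entropy replaces the focal modulation with fixed per-class weights, typically inverse class frequency; a pure-Python sketch (names and normalization are illustrative, not taken from the repo):

```python
import math
from collections import Counter

def class_weights(labels):
    """Inverse-frequency class weights, normalized to sum to 1."""
    counts = Counter(labels)
    n = len(labels)
    raw = {c: n / counts[c] for c in counts}
    z = sum(raw.values())
    return {c: w / z for c, w in raw.items()}

def weighted_ce(p, target, weights):
    """Weighted cross-entropy for one binary example."""
    p_t = p if target == 1 else 1.0 - p
    return -weights[target] * math.log(p_t)

w = class_weights([0] * 80 + [1] * 20)   # {0: 0.2, 1: 0.8}
```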
If Focal Loss is not working, try standard weighted cross-entropy.
Best Practices
- Start with defaults: The default Focal Loss configuration works well for most AD datasets
- Monitor both classes: Don’t just look at accuracy - track minority recall and F1
- Gradual adjustments: Change one parameter at a time to understand effects
- Check validation: Ensure minority performance holds on validation set (not just training)
- Use balanced accuracy: Primary metric for imbalanced datasets
- Consider cost: In medical applications, false negatives (missing converters) may be more costly than false positives
The combination of Focal Loss (α=0.90, γ=3.0) + minority forcing (20 epochs) + label smoothing (0.05) has been tuned for typical AD datasets with 15-25% converter ratio. Adjust based on your specific class distribution.