The Imbalance Problem

Alzheimer’s disease progression datasets are inherently imbalanced:
  • Stable subjects (class 0): Majority class (typically 70-85%)
  • Converter subjects (class 1): Minority class (typically 15-30%)
Standard cross-entropy loss often leads to models that:
  • Predict “stable” for all subjects
  • Achieve high accuracy but zero minority recall
  • Fail to identify patients at risk of conversion

Class Distribution

Check your dataset distribution (main.py:311-313):
class_labels = np.array([data.y.item() for data in dataset])
print(f"Class distribution: {np.bincount(class_labels)}")
# Example: [1842  412] → 82% stable, 18% converters

Focal Loss

Overview

Focal Loss addresses class imbalance by:
  1. Down-weighting easy examples (high-confidence correct predictions)
  2. Up-weighting hard examples (low-confidence or misclassified)
  3. Class-based weighting (giving more importance to minority class)

Mathematical Formulation

Implemented in FocalLoss.py:14-38:
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, label_smoothing=0.0):
        super().__init__()
        self.alpha = alpha      # Weight for positive class (converters)
        self.gamma = gamma      # Focusing parameter
        self.label_smoothing = label_smoothing
Formula:
FL(p_t) = -α_t (1 - p_t)^γ log(p_t)

where:
  p_t = model's probability for the true class
  α_t = α if y=1 (converter), else (1-α)
  γ = focusing parameter
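
To see the modulating effect concretely, here is a standalone numeric sketch of the formula (pure Python, independent of the codebase), comparing an easy and a hard converter example at α = 0.90, γ = 2.0:

```python
import math

def focal_loss(p_t, alpha_t, gamma):
    """Per-sample focal loss: -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)

# Converter (y=1), so alpha_t = alpha = 0.90
easy = focal_loss(p_t=0.95, alpha_t=0.90, gamma=2.0)  # confident, correct
hard = focal_loss(p_t=0.30, alpha_t=0.90, gamma=2.0)  # low confidence

print(f"easy: {easy:.5f}, hard: {hard:.5f}")
# The (1 - p_t)^gamma factor makes the hard example contribute
# roughly 4,600x more loss than the easy one
```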

Alpha Parameter

Controls class weighting (--focal_alpha):
alpha = 0.90
Weights:
  • Converters (class 1): 0.90
  • Stable (class 0): 0.10
Effect: Strong emphasis on detecting converters
Use When: Severe imbalance (< 20% converters)
Implementation (FocalLoss.py:29):
alpha_t = torch.where(targets == 1, self.alpha, 1 - self.alpha)
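
A NumPy analogue of that line (a standalone sketch with a hypothetical batch of labels) shows the per-sample weights it produces:

```python
import numpy as np

alpha = 0.90
targets = np.array([1, 0, 1, 0, 0])   # hypothetical batch labels
# NumPy analogue of: torch.where(targets == 1, self.alpha, 1 - self.alpha)
alpha_t = np.where(targets == 1, alpha, 1 - alpha)
print(alpha_t)  # [0.9 0.1 0.9 0.1 0.1]
```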

Gamma Parameter

Controls focusing strength (--focal_gamma):
gamma = 0.0
Effect: (1 - p_t)^0 = 1 → Standard weighted cross-entropy
Use When: Want only class weighting, no difficulty weighting
Implementation (FocalLoss.py:26-32):
pt = torch.exp(-ce_loss)  # Probability of correct class
focal_loss = alpha_t * (1 - pt) ** self.gamma * ce_loss
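
The focusing behavior across gamma values can be checked with a small NumPy sketch (values chosen for illustration, not from the codebase):

```python
import numpy as np

pt = np.array([0.95, 0.70, 0.30])      # easy, medium, hard examples
for gamma in (0.0, 2.0, 4.0):
    modulator = (1 - pt) ** gamma      # the (1 - p_t)^gamma factor
    print(f"gamma={gamma}: {np.round(modulator, 4)}")
# gamma=0.0 weights every example equally;
# larger gamma increasingly suppresses the easy examples
```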

Label Smoothing

Prevents overconfident predictions (--label_smoothing):
if self.label_smoothing > 0:
    num_classes = 2  # stable vs. converter
    targets_one_hot = F.one_hot(targets, num_classes=num_classes).float()
    # Smooth: [1, 0] → [0.975, 0.025] for label_smoothing=0.05
    targets_one_hot = targets_one_hot * (1 - self.label_smoothing) + \
                     self.label_smoothing / num_classes
    ce_loss = -torch.sum(targets_one_hot * F.log_softmax(inputs, dim=1), dim=1)
Effect:
  • Original: [1, 0] (hard label)
  • Smoothed (0.05): [0.975, 0.025] (soft label)
  • Smoothed (0.10): [0.95, 0.05] (softer label)
Benefits:
  • Reduces overfitting
  • Improves calibration
  • Helps with noisy labels
Reference: FocalLoss.py:17-21
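
The smoothing formula can be reproduced standalone; `smooth_labels` below is a hypothetical NumPy helper mirroring the expression above, not a function from the codebase:

```python
import numpy as np

def smooth_labels(targets, num_classes=2, label_smoothing=0.05):
    """Hard one-hot labels → smoothed soft labels (mirrors the formula above)."""
    one_hot = np.eye(num_classes)[targets]
    return one_hot * (1 - label_smoothing) + label_smoothing / num_classes

print(smooth_labels(np.array([1, 0]), label_smoothing=0.05))
# [[0.025 0.975]
#  [0.975 0.025]]
```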

Minority Class Forcing

Overview

An additional loss term applied during early training to encourage minority class predictions (main.py:465-482):
def minority_class_forcing_loss(logits, targets, epoch):
    if epoch > opt.minority_focus_epochs:  # Default: 20 epochs
        return torch.tensor(0.0)
    
    # Only apply to minority class samples
    minority_mask = (targets == 1)
    if minority_mask.sum() > 0:
        minority_logits = logits[minority_mask]
        # Encourage class 1 prediction
        minority_loss = F.cross_entropy(
            minority_logits,
            torch.ones(minority_mask.sum(), dtype=torch.long, device=device)
        )
        # Linearly decay forcing weight
        forcing_weight = 0.1 * (opt.minority_focus_epochs - epoch) / opt.minority_focus_epochs
        return forcing_weight * minority_loss
    
    return torch.tensor(0.0)

Forcing Schedule

Forcing weight decays over epochs:
Epoch   Forcing Weight   Effect
1       0.100            Strong forcing
5       0.075            Moderate forcing
10      0.050            Mild forcing
15      0.025            Weak forcing
20      0.000            No forcing
21+     0.000            No forcing
Configuration: --minority_focus_epochs 20 (default)
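
The decay can be reproduced standalone; this sketch mirrors the weight expression from the forcing-loss code, with the 0.1 base constant assumed from that snippet:

```python
def forcing_weight(epoch, minority_focus_epochs=20, base=0.1):
    """Linearly decaying forcing weight (same expression as the loss code)."""
    if epoch > minority_focus_epochs:
        return 0.0
    return base * (minority_focus_epochs - epoch) / minority_focus_epochs

for epoch in (5, 10, 15, 20, 25):
    print(f"epoch {epoch:2d}: weight {forcing_weight(epoch):.3f}")
```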

Total Loss

Combination of Focal Loss and forcing loss (main.py:524-529):
for batch in train_loader:
    logits = classifier(graph_seq, None, lengths)
    
    # Main loss
    loss = criterion(logits, labels)  # Focal Loss
    
    # Additional forcing loss (early epochs only)
    forcing_loss = minority_class_forcing_loss(logits, labels, epoch)
    
    # Combined
    total_batch_loss = loss + forcing_loss
    total_batch_loss.backward()
The forcing loss is applied IN ADDITION to Focal Loss, not as a replacement. This provides extra emphasis on minority class during the critical early training phase.

Monitoring Class Predictions

Training Monitoring

Track prediction distribution during training (main.py:510-549):
class_0_preds = 0
class_1_preds = 0

for batch in train_loader:
    logits = classifier(graph_seq, None, lengths)
    preds = logits.argmax(dim=1)
    class_0_preds += (preds == 0).sum().item()
    class_1_preds += (preds == 1).sum().item()

print(f"Predictions: Stable={class_0_preds}, Converter={class_1_preds}")
Warning Signs:
  • All predictions are class 0: Model collapsed to majority class
  • All predictions are class 1: Forcing is too strong
  • Balanced predictions: Good, but check if they’re accurate
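
These warning signs can be turned into a small check; `prediction_health` is a hypothetical monitoring helper (not part of main.py) that classifies a batch of predictions:

```python
import numpy as np

def prediction_health(preds):
    """Flag collapsed prediction distributions (hypothetical helper)."""
    counts = np.bincount(preds, minlength=2)
    if counts[1] == 0:
        return "collapsed to majority class"
    if counts[0] == 0:
        return "collapsed to minority class"
    return "predicting both classes"

print(prediction_health(np.zeros(32, dtype=int)))  # collapsed to majority class
print(prediction_health(np.array([0, 1, 0, 1])))   # predicting both classes
```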

Validation Monitoring

Evaluate with detailed metrics (main.py:319-380):
val_results = evaluate_detailed(val_loader)

print(f"Unique predictions: {val_results['unique_preds']}")
print(f"Minority F1: {val_results['minority_f1']:.3f}")
print(f"Minority Recall: {val_results['minority_recall']:.3f}")
Model Selection Requirement (main.py:565-567):
if val_results['unique_preds'] > 1 and current_auc > best_auc:
    # Only save if model predicts both classes
    save_model(...)
Models that predict only one class (collapsed models) are never saved, even if they have high accuracy. This ensures saved models can distinguish both classes.

Evaluation Metrics

Balanced Accuracy

Average of per-class recalls (main.py:358-359):
balanced_acc = np.mean([recall[0], recall[1]])
# = (Stable_Recall + Converter_Recall) / 2
Why Important:
  • Standard accuracy can be misleading with imbalance
  • A model predicting all “stable” gets 80% accuracy but 50% balanced accuracy
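
That 80%-versus-50% gap can be verified with a standalone NumPy sketch of a collapsed model on an 80/20 split (illustrative data, not from the codebase):

```python
import numpy as np

y_true = np.array([0] * 80 + [1] * 20)   # 80% stable, 20% converters
y_pred = np.zeros(100, dtype=int)        # collapsed model: always "stable"

accuracy = (y_pred == y_true).mean()                   # 0.80
recall_stable = (y_pred[y_true == 0] == 0).mean()      # 1.0
recall_converter = (y_pred[y_true == 1] == 1).mean()   # 0.0
balanced_acc = (recall_stable + recall_converter) / 2  # 0.50

print(f"accuracy={accuracy:.2f}, balanced={balanced_acc:.2f}")
```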

Minority F1 Score

F1 score for converter class (main.py:356):
minority_f1 = f1[1]  # F1 for class 1 (converters)
# F1 = 2 * (Precision * Recall) / (Precision + Recall)
Interpretation:
  • F1 = 0.0: Model never predicts converters
  • F1 = 0.5: Reasonable performance
  • F1 = 0.7+: Good minority class detection

AUC-ROC

Area under ROC curve (main.py:362):
probs_positive = np.array(all_probs)[:, 1]
auc_score = roc_auc_score(all_targets, probs_positive)
Advantages:
  • Threshold-independent metric
  • Considers full probability distribution
  • Good for imbalanced datasets
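
AUC has an intuitive reading: the probability that a randomly chosen converter is scored higher than a randomly chosen stable subject. This standalone sketch computes it via that pairwise ranking (equivalent to roc_auc_score when there are no tied scores; the data is illustrative):

```python
import numpy as np

def auc_rank(y_true, scores):
    """AUC as the fraction of (converter, stable) pairs ranked correctly."""
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    return (pos[:, None] > neg[None, :]).mean()

y = np.array([0, 0, 1, 1, 0])
s = np.array([0.1, 0.4, 0.35, 0.8, 0.2])
print(auc_rank(y, s))  # 5 of 6 pairs correct → 0.8333...
```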

Recommended Configurations

Severe Imbalance (< 20% minority)

python main.py \
  --focal_alpha 0.95 \
  --focal_gamma 4.0 \
  --minority_focus_epochs 30 \
  --label_smoothing 0.1
Strategy: Maximum minority focus with extended forcing period

Moderate Imbalance (20-35% minority)

python main.py \
  --focal_alpha 0.90 \
  --focal_gamma 3.0 \
  --minority_focus_epochs 20 \
  --label_smoothing 0.05
Strategy: Default configuration (balanced approach)

Mild Imbalance (35-45% minority)

python main.py \
  --focal_alpha 0.75 \
  --focal_gamma 2.0 \
  --minority_focus_epochs 10 \
  --label_smoothing 0.05
Strategy: Lighter intervention, rely mainly on Focal Loss

No Forcing (Focal Loss Only)

python main.py \
  --focal_alpha 0.80 \
  --focal_gamma 2.0 \
  --minority_focus_epochs 0 \
  --label_smoothing 0.05
Strategy: Disable forcing, use only Focal Loss (set epochs to 0)

Gradient Clipping

Prevents exploding gradients during minority forcing (main.py:533-536):
torch.nn.utils.clip_grad_norm_(
    list(fold_encoder.parameters()) + list(classifier.parameters()),
    max_norm=1.0
)
Effect: Scales gradients if total norm exceeds 1.0
Gradient clipping is particularly important when using minority forcing, as the additional loss term can cause large gradient spikes early in training.

Troubleshooting

Model Predicts Only Majority Class

Symptoms:
  • All predictions are class 0
  • High accuracy but zero minority recall
  • unique_preds = 1
Solutions:
  1. Increase focal_alpha (e.g., 0.95)
  2. Increase focal_gamma (e.g., 4.0)
  3. Extend minority_focus_epochs (e.g., 30)
  4. Reduce learning rate (--lr 0.0005)
  5. Check if minority class has enough samples (< 10% is very hard)

Model Predicts Only Minority Class

Symptoms:
  • All predictions are class 1
  • Low accuracy
  • Forced too aggressively
Solutions:
  1. Decrease focal_alpha (e.g., 0.75)
  2. Decrease focal_gamma (e.g., 1.0)
  3. Reduce minority_focus_epochs (e.g., 10)
  4. Lower forcing weight in code (change 0.1 to 0.05)

Unstable Training

Symptoms:
  • Predictions oscillate between all-0 and all-1
  • Loss spikes
  • NaN gradients
Solutions:
  1. Add/increase label smoothing (--label_smoothing 0.1)
  2. Reduce learning rate (--lr 0.0005)
  3. Check gradient clipping is enabled
  4. Reduce batch size for more stable gradients
  5. Lower focal_gamma (extreme values like 5+ can destabilize)

Low Minority F1 Despite Predictions

Symptoms:
  • Model predicts some converters
  • But minority F1 < 0.3
  • Predictions are often wrong
Solutions:
  1. Model may be underfitting:
    • Increase model capacity (--gnn_hidden_dim 512 --lstm_hidden_dim 128)
    • Train longer (--n_epochs 150)
  2. Data quality issues:
    • Check for label noise
    • Verify converter definitions
  3. Reduce label smoothing (may be over-regularizing)

Class Weights Alternative

If Focal Loss is not working, try standard weighted cross-entropy:
# Instead of Focal Loss
class_counts = np.bincount(train_labels)
class_weights = torch.tensor(
    [1.0 / count for count in class_counts],
    dtype=torch.float32,
    device=device
)
criterion = nn.CrossEntropyLoss(weight=class_weights)
This is not currently implemented in main.py but can be added as an alternative to Focal Loss. You would need to modify the criterion initialization at main.py:305-309.

Best Practices

  1. Start with defaults: The default Focal Loss configuration works well for most AD datasets
  2. Monitor both classes: Don’t just look at accuracy - track minority recall and F1
  3. Gradual adjustments: Change one parameter at a time to understand effects
  4. Check validation: Ensure minority performance holds on validation set (not just training)
  5. Use balanced accuracy: Primary metric for imbalanced datasets
  6. Consider cost: In medical applications, false negatives (missing converters) may be more costly than false positives
The combination of Focal Loss (α=0.90, γ=3.0) + minority forcing (20 epochs) + label smoothing (0.05) has been tuned for typical AD datasets with 15-25% converter ratio. Adjust based on your specific class distribution.
