Overview

STGNN uses a comprehensive set of evaluation metrics to assess model performance on the binary classification task of predicting Alzheimer’s disease progression (stable vs. converter). All metrics are calculated by the evaluate_detailed function in main.py:319-380.

Primary Metrics

Accuracy

Overall classification accuracy across all test subjects:
accuracy = np.mean(np.array(all_preds) == np.array(all_targets))
Measures the proportion of correct predictions (both stable and converter classes) out of total predictions.

Balanced Accuracy

Average of per-class recall to handle class imbalance:
balanced_acc = np.mean([recall[0], recall[1]])
Calculated as the mean of recall for stable subjects (class 0) and recall for converters (class 1). This metric is particularly important given the dataset's class imbalance, where converters are the minority class.

Why it matters: Standard accuracy can be misleading when classes are imbalanced. A model that predicts "stable" for everyone might achieve high accuracy but would fail to identify any converters.
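To see the difference concretely, here is a minimal sketch using hypothetical labels (9 stable subjects, 1 converter) and a degenerate model that always predicts "stable". The `balanced_accuracy_score` helper from scikit-learn computes the same per-class recall average as the snippet above:

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score

# Hypothetical imbalanced test set: 9 stable (class 0), 1 converter (class 1)
all_targets = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
# Degenerate model: predicts "stable" for every subject
all_preds = np.zeros(10, dtype=int)

accuracy = np.mean(all_preds == all_targets)                # 0.9 -- looks strong
balanced = balanced_accuracy_score(all_targets, all_preds)  # 0.5 -- reveals the failure
print(f"accuracy={accuracy:.2f}, balanced_accuracy={balanced:.2f}")
```

Standard accuracy rewards the majority-class shortcut; balanced accuracy exposes that the converter class is never identified (recall 1.0 for stable, 0.0 for converters, averaging to 0.5).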

Area Under ROC Curve (AUC)

Discriminative ability between stable and converter classes:
probs_positive = np.array(all_probs)[:, 1]
auc_score = roc_auc_score(all_targets, probs_positive)
Measures the model’s ability to rank converters higher than stable subjects based on predicted probabilities. AUC ranges from 0 to 1, where:
  • 1.0 = perfect discrimination
  • 0.5 = random guessing
  • Values above 0.8 indicate strong predictive power
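The ranking interpretation can be illustrated with hypothetical probabilities for six subjects. Because every converter receives a higher class-1 probability than every stable subject, the AUC is exactly 1.0:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical softmax outputs, columns = (stable, converter)
all_probs = np.array([
    [0.9, 0.1], [0.8, 0.2], [0.6, 0.4],   # stable subjects
    [0.4, 0.6], [0.3, 0.7], [0.2, 0.8],   # converters
])
all_targets = [0, 0, 0, 1, 1, 1]

probs_positive = all_probs[:, 1]  # probability assigned to class 1 (converter)
auc_score = roc_auc_score(all_targets, probs_positive)
print(auc_score)  # 1.0: every converter is ranked above every stable subject
```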

F1 Score

Harmonic mean of precision and recall, calculated separately for each class:
precision, recall, f1, _ = precision_recall_fscore_support(
    all_targets, all_preds, average=None, zero_division=0
)
Minority F1 (converters): The F1 score for class 1 is tracked as the primary metric for minority class performance:
minority_f1 = f1[1] if len(f1) > 1 else 0
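A small worked example with hypothetical labels shows how the per-class arrays come back and why index 1 picks out the minority class:

```python
from sklearn.metrics import precision_recall_fscore_support

# Hypothetical predictions; class 1 (converters) is the minority
all_targets = [0, 0, 0, 0, 0, 0, 1, 1]
all_preds   = [0, 0, 0, 0, 0, 1, 1, 0]

precision, recall, f1, support = precision_recall_fscore_support(
    all_targets, all_preds, average=None, zero_division=0
)
# For class 1: TP=1, FP=1, FN=1 -> precision=0.5, recall=0.5, F1=0.5
minority_f1 = f1[1] if len(f1) > 1 else 0
print(minority_f1)  # 0.5
```

With `average=None`, each returned array has one entry per class, so `f1[0]` is the stable-class score and `f1[1]` the converter-class score.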

Per-Class Metrics

Precision

Proportion of positive predictions that are correct:
Precision = True Positives / (True Positives + False Positives)
Tracked separately for stable (class 0) and converter (class 1) predictions:
minority_precision = precision[1] if len(precision) > 1 else 0
High precision for converters means fewer false alarms (stable patients incorrectly flagged as converters).

Recall (Sensitivity)

Proportion of actual positives correctly identified:
Recall = True Positives / (True Positives + False Negatives)
minority_recall = recall[1] if len(recall) > 1 else 0
High recall for converters means the model successfully identifies most patients who will actually convert.
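The two formulas above can be checked directly against confusion-matrix counts. This sketch uses hypothetical labels with one false alarm and one missed converter:

```python
from sklearn.metrics import confusion_matrix

# Hypothetical labels: 4 stable, 4 converters
all_targets = [0, 0, 0, 0, 1, 1, 1, 1]
all_preds   = [0, 0, 0, 1, 1, 1, 1, 0]

# confusion_matrix returns [[TN, FP], [FN, TP]] for labels (0, 1)
tn, fp, fn, tp = confusion_matrix(all_targets, all_preds).ravel()

precision = tp / (tp + fp)  # 3 / (3 + 1) = 0.75: one stable subject falsely flagged
recall    = tp / (tp + fn)  # 3 / (3 + 1) = 0.75: one converter missed
print(f"precision={precision:.2f}, recall={recall:.2f}")
```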

Loss Function

Focal Loss

The model uses focal loss to address class imbalance:
criterion = FocalLoss(
    alpha=0.90,      # Weight for minority class
    gamma=3.0,       # Focusing parameter
    label_smoothing=0.05
)
Focal loss down-weights easy examples and focuses on hard-to-classify cases. The alpha parameter (0.90) gives more weight to the minority converter class. See FocalLoss.py for implementation details.
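The actual implementation lives in FocalLoss.py; as a rough illustration of the math only, here is a NumPy sketch of alpha-weighted binary focal loss (the label-smoothing term is omitted for brevity, and the function name is ours, not the repo's):

```python
import numpy as np

def focal_loss(p, targets, alpha=0.90, gamma=3.0):
    """Binary focal loss sketch: -alpha_t * (1 - p_t)^gamma * log(p_t).

    p       -- predicted probability of class 1 for each sample
    targets -- ground-truth labels (0 = stable, 1 = converter)
    """
    p = np.asarray(p, dtype=float)
    t = np.asarray(targets, dtype=float)
    p_t = np.where(t == 1, p, 1 - p)               # probability of the true class
    alpha_t = np.where(t == 1, alpha, 1 - alpha)   # converters weighted at 0.90
    return np.mean(-alpha_t * (1 - p_t) ** gamma * np.log(p_t))

# The (1 - p_t)^gamma factor makes an easy example (p_t = 0.95)
# contribute far less than a hard one (p_t = 0.3)
print(focal_loss([0.95], [1]), focal_loss([0.3], [1]))
```

With gamma = 3.0 the easy example's loss is scaled by (0.05)^3 while the hard example's is scaled by (0.7)^3, which is the down-weighting of easy cases described above.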

Cross-Validation Metrics

Metrics are aggregated across all folds in 5-fold stratified cross-validation:
fold_results = {
    'test_acc': [],
    'balanced_acc': [],
    'minority_f1': [],
    'test_auc': [],
    'train_acc': [],
    'balanced_train_acc': []
}
Final results report mean and standard deviation:
mean = np.mean(values)
std = np.std(values)
print(f"{metric}: {mean:.3f} ± {std:.3f}")
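Putting the two snippets together, here is a minimal sketch of the aggregation step using hypothetical per-fold AUC values (the numbers are illustrative, not real results):

```python
import numpy as np

# Hypothetical AUC scores collected from the 5 folds
fold_results = {'test_auc': [0.81, 0.78, 0.85, 0.80, 0.83]}

for metric, values in fold_results.items():
    mean = np.mean(values)
    std = np.std(values)
    print(f"{metric}: {mean:.3f} ± {std:.3f}")  # test_auc: 0.814 ± 0.024
```

Note that `np.std` defaults to the population standard deviation (`ddof=0`); pass `ddof=1` if a sample standard deviation across folds is preferred.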

Model Selection Criteria

The best model is selected based on validation AUC:
if (val_results['unique_preds'] > 1 and current_auc > best_auc) or \
   (val_results['unique_preds'] > 1 and best_auc == 0):
    best_auc = val_results['auc']
    best_minority_f1 = val_results['minority_f1']
    best_balanced_acc = val_results['balanced_accuracy']
The model must predict both classes (unique_preds > 1) to be considered valid.

Evaluation Return Structure

The evaluate_detailed function returns a comprehensive dictionary:
result = {
    'loss': avg_loss,
    'accuracy': accuracy,
    'balanced_accuracy': balanced_acc,
    'minority_precision': minority_precision,
    'minority_recall': minority_recall,
    'minority_f1': minority_f1,
    'auc': auc_score,
    'predictions': all_preds,
    'targets': all_targets,
    'unique_preds': len(set(all_preds))
}
When return_probs=True is specified, predicted probabilities are also included:
if return_probs:
    result['probabilities'] = all_probs

Classification Report

A full scikit-learn classification report is printed for the test set:
print(classification_report(test_results['targets'], test_results['predictions'],
                            target_names=['Stable', 'Converter'],
                            zero_division=0))
This provides precision, recall, F1 score, and support for both classes.