
Best Model Performance

The best performing model uses a GraphSAGE-LSTM architecture with TopK pooling and focal loss, achieving the following results on the ADNI test set:
| Metric | Score |
| --- | --- |
| Test Accuracy | 82.9% |
| Balanced Accuracy | 77.1% |
| AUC-ROC | 85.4% |
| Minority F1 (Converters) | Reported in detailed metrics |
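Balanced accuracy is the mean of per-class recall, which is why it sits below overall accuracy on this imbalanced dataset. A minimal sketch (the values here are illustrative, not the model's actual outputs):

```python
import numpy as np

def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall (what sklearn's balanced_accuracy_score computes)."""
    classes = np.unique(y_true)
    recalls = [
        np.mean(y_pred[y_true == c] == c)  # recall for class c
        for c in classes
    ]
    return float(np.mean(recalls))

# A majority-class-heavy prediction scores high on plain accuracy
# but lower on balanced accuracy.
y_true = np.array([0] * 90 + [1] * 10)
y_pred = np.array([0] * 90 + [1] * 5 + [0] * 5)

acc = np.mean(y_true == y_pred)              # 0.95
bal_acc = balanced_accuracy(y_true, y_pred)  # (1.0 + 0.5) / 2 = 0.75
```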

Architecture Configuration

The best model configuration:

Graph Neural Network

| Component | Value |
| --- | --- |
| Layer Type | GraphSAGE |
| Hidden Dimension | 256 |
| Number of Layers | 2 |
| Activation Function | ELU |
| Pooling Strategy | TopK (ratio=0.3) |
| Output Dimension | 256 |
| Dropout | 0.2 |

Temporal Model

| Component | Value |
| --- | --- |
| Model Type | LSTM |
| Hidden Dimension | 64 |
| Number of Layers | 1 |
| Bidirectional | Yes |
| Dropout | 0.45 |
| Max Visits | 10 |

Training Configuration

| Parameter | Value |
| --- | --- |
| Loss Function | Focal Loss |
| Focal Alpha | 0.90 |
| Focal Gamma | 3.0 |
| Label Smoothing | 0.05 |
| Learning Rate | 0.001 |
| Batch Size | 16 subjects |
| Epochs | 100 |
| Optimizer | Adam (β₁=0.9, β₂=0.999) |
| Weight Decay | 1e-4 |
| Gradient Clipping | max_norm=1.0 |
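Gradient clipping at max_norm=1.0 rescales the whole gradient vector whenever its global L2 norm exceeds 1.0, mirroring what `torch.nn.utils.clip_grad_norm_` does. A numpy sketch of the rescaling rule:

```python
import numpy as np

def clip_grad_norm(grads, max_norm=1.0):
    """Rescale a list of gradient arrays so their global L2 norm is <= max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

grads = [np.array([3.0, 0.0]), np.array([0.0, 4.0])]  # global norm = 5.0
clipped, norm = clip_grad_norm(grads, max_norm=1.0)
# After clipping, the global norm is exactly max_norm = 1.0.
```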

Training Strategy

Minority Class Handling

To address severe class imbalance (converters are the minority class), the model employs multiple strategies:

1. Focal Loss: Down-weights easy examples, focuses on hard cases

   ```python
   criterion = FocalLoss(alpha=0.90, gamma=3.0, label_smoothing=0.05)
   ```

2. Minority Class Forcing: Additional loss term for the first 20 epochs, decaying linearly to zero

   ```python
   forcing_weight = 0.1 * (minority_focus_epochs - epoch) / minority_focus_epochs
   ```

3. Stratified Cross-Validation: Ensures balanced class distribution across folds
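The focal loss used here follows the standard formulation FL(p_t) = −α_t (1 − p_t)^γ log(p_t). A numpy sketch of the binary case (this is the textbook form, not the repository's exact `FocalLoss` class; label smoothing is omitted for brevity):

```python
import numpy as np

def binary_focal_loss(p, y, alpha=0.90, gamma=3.0):
    """Textbook binary focal loss.

    p: predicted probability of the positive (converter) class
    y: true label in {0, 1}
    """
    p_t = p if y == 1 else 1.0 - p            # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)

# An easy example (confident, correct) is down-weighted far more than a hard one:
easy = binary_focal_loss(p=0.95, y=1)   # (1 - 0.95)^3 shrinks the loss toward zero
hard = binary_focal_loss(p=0.30, y=1)   # large (1 - p_t)^gamma factor keeps it high
```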

Learning Rate Scheduling

ReduceLROnPlateau scheduler monitors validation balanced accuracy:

```python
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=10, min_lr=1e-6
)
```

The learning rate is halved if validation balanced accuracy does not improve for 10 epochs.

Model Selection

The best model is selected based on validation AUC:

```python
if val_results['unique_preds'] > 1 and current_auc > best_auc:
    best_auc = val_results['auc']
    best_minority_f1 = val_results['minority_f1']
    best_balanced_acc = val_results['balanced_accuracy']
```

Requirement: the model must predict both classes (unique_preds > 1) to be considered valid.

Cross-Validation Results

5-fold stratified cross-validation with an 80/20 train-val split:

```python
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
```

Final results report mean ± standard deviation across folds:

```python
for metric, values in fold_results.items():
    mean = np.mean(values)
    std = np.std(values)
    print(f"{metric}: {mean:.3f} ± {std:.3f}")
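The `fold_results` mapping used above can be built by appending each fold's metrics as it finishes. A minimal sketch with made-up values (not the actual results):

```python
import numpy as np
from collections import defaultdict

fold_results = defaultdict(list)

# Illustrative per-fold metrics, one dict per completed fold.
for fold_metrics in [
    {'test_acc': 0.82, 'test_auc': 0.85},
    {'test_acc': 0.84, 'test_auc': 0.86},
    {'test_acc': 0.83, 'test_auc': 0.84},
]:
    for metric, value in fold_metrics.items():
        fold_results[metric].append(value)

for metric, values in fold_results.items():
    print(f"{metric}: {np.mean(values):.3f} ± {np.std(values):.3f}")
```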

Tracked Metrics

Across all folds, the following metrics are collected:
| Metric | Description |
| --- | --- |
| `test_acc` | Test set overall accuracy |
| `balanced_acc` | Test set balanced accuracy |
| `minority_f1` | Test set F1 for converter class |
| `test_auc` | Test set AUC-ROC |
| `train_acc` | Training set overall accuracy |
| `balanced_train_acc` | Training set balanced accuracy |

Data Split Strategy

Subject-Level Splitting

Data is split at the subject level (not visit level) to prevent data leakage:

```python
# All visits from a subject go into the same split
for fold_idx, (train_val_idx, test_idx) in enumerate(skf.split(subjects, labels)):
    train_val_subjects = [subjects[i] for i in train_val_idx]
    test_subjects = [subjects[i] for i in test_idx]
```

Stratification

Subjects are stratified by their conversion label to maintain class balance:

```python
subject_labels = {}
for subj_id, graphs in dataset.subject_graph_dict.items():
    subject_labels[subj_id] = graphs[0].y.item()  # Use baseline label
```

Train-Val-Test Split

Each fold:
  • 64% training
  • 16% validation
  • 20% test

```python
train_idx, val_idx = train_test_split(
    np.arange(len(train_val_subjects)),
    test_size=0.2,
    random_state=seed + fold_idx,
    stratify=train_val_labels
)
```
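The 64/16/20 fractions follow from nesting the two splits: the outer 5-fold CV holds out 20% of subjects for test, and the remaining 80% is split 80/20 into train and validation. As a quick check:

```python
test_frac = 1 / 5                  # outer StratifiedKFold(n_splits=5)
train_val_frac = 1 - test_frac     # 0.80 of subjects go to train+val
val_frac = train_val_frac * 0.2    # inner train_test_split(test_size=0.2) -> 0.16
train_frac = train_val_frac * 0.8  # remainder of the inner split -> 0.64
```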

Reproducibility

All random seeds are set for reproducibility:

```python
def set_random_seeds(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```
Seeds are reset at key points:
  • Before model creation
  • Before each fold
  • Before data loaders
  • Before classifier initialization
  • Before optimizer creation

Early Stopping

Patience-based early stopping monitors validation AUC:

```python
patience = 20
patience_counter = 0

if current_auc > best_auc:
    patience_counter = 0
else:
    patience_counter += 1
```

Training continues for the full 100 epochs, but the best model is saved based on validation performance.
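Putting the pieces together, the patience counter and best-model tracking typically live in a loop like this (a simplified sketch, not the repository's exact training loop):

```python
def select_best_epoch(val_aucs, patience=20):
    """Return (best_epoch, best_auc) under patience-based tracking."""
    best_auc, best_epoch = float('-inf'), -1
    patience_counter = 0
    for epoch, auc in enumerate(val_aucs):
        if auc > best_auc:
            best_auc, best_epoch = auc, epoch
            patience_counter = 0     # reset on improvement
        else:
            patience_counter += 1    # would trigger a stop once it exceeds patience
    return best_epoch, best_auc

# Validation AUC improves early, then plateaus: the epoch-2 checkpoint is kept.
best_epoch, best_auc = select_best_epoch([0.70, 0.78, 0.82, 0.81, 0.80])
# best_epoch == 2, best_auc == 0.82
```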

Model Checkpointing

The best model state is saved for each fold:

```python
best_model_state = {
    'encoder': fold_encoder.state_dict().copy(),
    'classifier': classifier.state_dict().copy(),
    'epoch': epoch,
    'val_results': val_results
}

torch.save(best_model_state, os.path.join(opt.save_path, f'best_model_fold{fold}.pth'))
```

Per-Fold Evaluation

Each fold reports output like:

```text
Starting FOLD 1/5
Test Loss: 0.3421 | Balanced Accuracy: 0.771
Prediction Distribution: {0: 134, 1: 36}

Classification Report:
              precision    recall  f1-score   support
      Stable       0.86      0.91      0.88       135
   Converter       0.67      0.54      0.60        35
```

Conversion-Specific Performance

After all folds, aggregated conversion analysis shows per-group accuracy:

```text
CN-Stable:
  Overall: 225/260 correct (0.865)
  Stable predictions: 225/255 correct (0.882)
  Converter predictions: 0/5 correct (0.000)

MCI->AD:
  Overall: 90/120 correct (0.750)
  Stable predictions: 20/30 correct (0.667)
  Converter predictions: 70/90 correct (0.778)
```

See Conversion Analysis for details.
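Per-group accuracies like these can be computed by restricting the test predictions to one conversion trajectory at a time. A minimal sketch with illustrative data (not the actual test set):

```python
import numpy as np

def group_accuracy(groups, y_true, y_pred, target_group):
    """Overall accuracy restricted to subjects in one conversion group."""
    mask = np.array([g == target_group for g in groups])
    return float(np.mean(y_true[mask] == y_pred[mask]))

# Illustrative data: conversion group, true class (1 = converter), prediction.
groups = ['CN-Stable', 'CN-Stable', 'MCI->AD', 'MCI->AD']
y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0, 1, 1, 1])

cn_acc = group_accuracy(groups, y_true, y_pred, 'CN-Stable')  # 1/2 correct
mci_acc = group_accuracy(groups, y_true, y_pred, 'MCI->AD')   # 2/2 correct
```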

Command to Reproduce

To reproduce the best results:

```shell
python main.py \
  --model_type LSTM \
  --layer_type GraphSAGE \
  --gnn_hidden_dim 256 \
  --gnn_num_layers 2 \
  --gnn_activation elu \
  --use_topk_pooling \
  --topk_ratio 0.3 \
  --lstm_hidden_dim 64 \
  --lstm_num_layers 1 \
  --lstm_bidirectional True \
  --focal_alpha 0.90 \
  --focal_gamma 3.0 \
  --label_smoothing 0.05 \
  --minority_focus_epochs 20 \
  --n_epochs 100 \
  --batch_size 16 \
  --lr 0.001 \
  --num_folds 5 \
  --max_visits 10
```

Note: Most of these are default values and can be omitted.

Comparison with Baselines

The spatiotemporal approach outperforms:
  1. Static GNN: Using only baseline visit (no temporal information)
  2. Simple RNN on tabular features: Without graph structure
  3. Traditional ML: SVM, Random Forest on hand-crafted features
The combination of graph-based connectivity analysis and temporal sequence modeling is key to achieving strong performance on this challenging task.
