
Overview

The STGNN training uses stratified k-fold cross-validation to ensure robust evaluation and prevent overfitting. The default configuration performs 5-fold cross-validation with stratified splitting at the subject level.

Stratification Strategy

Subject-Level Splitting

Cross-validation splits are performed at the subject level, not the visit level (main.py:117-169):
from sklearn.model_selection import StratifiedKFold

def get_kfold_splits(dataset, num_folds=5, seed=42):
    # Extract one label per subject
    subject_labels = {}
    for subj_id, graphs in dataset.subject_graph_dict.items():
        if graphs and hasattr(graphs[0], 'y'):
            subject_labels[subj_id] = graphs[0].y.item()

    subjects = list(subject_labels.keys())
    labels = [subject_labels[s] for s in subjects]

    # Stratified k-fold on subjects
    skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=seed)
Why Subject-Level?
  • Prevents data leakage (visits from same subject in train and test)
  • Reflects real-world deployment (predicting for new patients)
  • More conservative performance estimates
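The leakage risk is easy to demonstrate in isolation. The toy sketch below (not project code) uses scikit-learn's GroupKFold: grouping by subject guarantees that no subject's visits straddle the train/test boundary.

```python
# Toy illustration (not project code): grouping by subject keeps every
# visit of a subject on one side of the split, preventing leakage.
from sklearn.model_selection import GroupKFold

visits = list(range(8))                                      # 8 visits
subjects = ["s1", "s1", "s2", "s2", "s3", "s3", "s4", "s4"]  # 2 visits each

gkf = GroupKFold(n_splits=2)
for train_idx, test_idx in gkf.split(visits, groups=subjects):
    train_subj = {subjects[i] for i in train_idx}
    test_subj = {subjects[i] for i in test_idx}
    assert not train_subj & test_subj  # no subject appears on both sides
```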

Stratification

Stratification ensures balanced class distribution across folds:
for fold_idx, (train_val_idx, test_idx) in enumerate(skf.split(subjects, labels)):
    train_val_subjects = [subjects[i] for i in train_val_idx]
    test_subjects = [subjects[i] for i in test_idx]
Class Balance:
  • Each fold maintains approximately the same ratio of stable vs. converter subjects
  • Prevents folds with all/no converters
  • Critical for imbalanced datasets
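You can check the balancing behaviour directly with toy labels (an illustrative sketch, not the project's data): StratifiedKFold divides each class evenly across folds.

```python
from collections import Counter
from sklearn.model_selection import StratifiedKFold

# Toy labels: 80 stable (0), 20 converters (1) -- the kind of imbalance
# stratification protects against.
labels = [0] * 80 + [1] * 20
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

fold_counts = []
for _, test_idx in skf.split(labels, labels):
    fold_counts.append(Counter(labels[i] for i in test_idx))

# Each test fold receives exactly 16 stable and 4 converter subjects
print(fold_counts)
```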

Split Configuration

Three-Way Split

Each fold is divided into train, validation, and test sets (main.py:134-139):
# 80% train+val, 20% test per fold (skf.split yields index arrays, not subjects)
for fold_idx, (train_val_idx, test_idx) in enumerate(skf.split(subjects, labels)):
    train_val_subjects = [subjects[i] for i in train_val_idx]
    train_val_labels = [labels[i] for i in train_val_idx]

    # Further split train+val into 80% train, 20% val
    train_idx, val_idx = train_test_split(
        np.arange(len(train_val_subjects)),
        test_size=0.2,
        random_state=seed + fold_idx,
        stratify=train_val_labels
    )
Final Ratios (for 5-fold CV):
  • Train: 64% of total subjects
  • Validation: 16% of total subjects
  • Test: 20% of total subjects
Example (100 subjects):
  • Fold 1: 64 train, 16 val, 20 test
  • Fold 2: 64 train, 16 val, 20 test
  • …
  • Fold 5: 64 train, 16 val, 20 test
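The ratios follow from simple arithmetic: 5-fold CV holds out 1/5 for test, and train_test_split then takes 20% of the remaining 80% for validation.

```python
# Derivation of the 64/16/20 split for 5-fold CV
test = 1 / 5               # 0.20 held out by the k-fold
val = (1 - test) * 0.2     # 0.16 of the total
train = (1 - test) * 0.8   # 0.64 of the total
assert abs(train + val + test - 1.0) < 1e-9
```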

Visit Index Mapping

After subject-level splits, map subjects back to visit indices (main.py:144-154):
def get_visit_indices(subj_list):
    indices = []
    subj_set = set(subj_list)
    for i, data in enumerate(dataset):
        sid = getattr(data, 'subj_id', None)
        if sid is None:
            continue
        base_subject_id = sid.split('_run')[0] if '_run' in sid else sid
        if base_subject_id in subj_set:
            indices.append(i)
    return indices

tr_index = get_visit_indices(train_subjects)
val_index = get_visit_indices(val_subjects)
te_index = get_visit_indices(test_subjects)
Result: Each split contains ALL visits for subjects in that split.
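A self-contained toy version of this mapping (using the same `_run` suffix convention the code above splits on) shows how every visit of a selected subject lands in the split:

```python
# Toy visit IDs following the "<subject>_run<k>" pattern split on above
visits = ["s1_run1", "s1_run2", "s2_run1", "s3_run1", "s3_run2"]

def get_visit_indices(subj_list):
    subj_set = set(subj_list)
    return [i for i, sid in enumerate(visits)
            if (sid.split('_run')[0] if '_run' in sid else sid) in subj_set]

print(get_visit_indices(["s1", "s3"]))  # -> [0, 1, 3, 4]: all visits of s1 and s3
```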

Fold Configuration

Number of Folds

Configurable via --num_folds argument:
# 5-fold cross-validation (default)
python main.py --num_folds 5

# 10-fold cross-validation (more conservative)
python main.py --num_folds 10

# Single train/val/test split
python main.py --num_folds 1
--num_folds 1 performs a single 64/16/20 train/val/test split. Useful for faster experimentation or when you have abundant data.

Fold Independence

Each fold trains a completely independent model (main.py:201-228):
for fold, split in enumerate(fold_splits):
    print(f"STARTING FOLD {fold + 1}/{opt.num_folds}")
    
    # Create fresh copy of encoder for this fold
    set_random_seeds(42)
    fold_encoder = copy.deepcopy(encoder).to(device)
    
    # Fresh optimizer and scheduler
    optimizer = torch.optim.Adam([...], lr=opt.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(...)
    
    # Fresh criterion
    criterion = FocalLoss(alpha=opt.focal_alpha, gamma=opt.focal_gamma)
No Information Leakage:
  • Models do not share weights between folds
  • Each fold sees a different test set
  • Training restarts from scratch (or from pretrained weights)
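A minimal sketch of why the deepcopy matters: mutating one fold's copy leaves the shared template and every other fold untouched.

```python
import copy

# Illustration only: each fold deep-copies the encoder, so folds never
# share parameters.
class Encoder:
    def __init__(self):
        self.weights = [0.1, 0.2, 0.3]

encoder = Encoder()
fold_models = [copy.deepcopy(encoder) for _ in range(5)]

fold_models[0].weights[0] = 9.9      # "training" fold 0 mutates only its copy
assert encoder.weights[0] == 0.1     # template unchanged
assert fold_models[1].weights[0] == 0.1  # other folds unchanged
```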

Reproducibility

Deterministic Splitting

Random seeds ensure identical splits across runs:
# Global seed for cross-validation splits
set_random_seeds(42)

# StratifiedKFold uses fixed seed
skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)

# Validation split uses fold-dependent seed
train_test_split(..., random_state=seed + fold_idx, ...)
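Determinism is easy to verify with a quick standalone check (toy labels, not project code): identical seeds yield identical fold assignments across runs.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

labels = np.array([0] * 8 + [1] * 4)

def fold_assignments(seed):
    skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=seed)
    return [tuple(test_idx.tolist()) for _, test_idx in skf.split(labels, labels)]

# Same seed -> byte-identical folds, run after run
assert fold_assignments(42) == fold_assignments(42)
```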

Data Loader Seeding

TemporalDataLoader uses seeded shuffling (TemporalDataLoader.py:28-32):
if self.shuffle:
    rng = np.random.RandomState(self.seed)
    rng.shuffle(self.subjects)
Result: Subject order is shuffled but reproducible within each fold.
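The same seeded-RNG pattern can be checked in isolation with a toy subject list:

```python
import numpy as np

subjects = ["s1", "s2", "s3", "s4", "s5"]

def shuffled(seed):
    rng = np.random.RandomState(seed)
    order = list(subjects)
    rng.shuffle(order)  # in-place, seeded shuffle
    return order

assert shuffled(42) == shuffled(42)      # reproducible across calls
assert sorted(shuffled(42)) == subjects  # a permutation, nothing lost
```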

Training Per Fold

Independent Training

Each fold runs a complete training loop (main.py:499-588):
for epoch in range(1, opt.n_epochs + 1):
    # Training
    for batch in train_loader:
        # Forward, backward, optimize
        ...
    
    # Validation
    val_results = evaluate_detailed(val_loader)
    
    # Model selection based on validation AUC
    if val_results['auc'] > best_auc:
        best_auc = val_results['auc']
        save_model(...)
Model Selection:
  • Each fold saves its own best model based on validation AUC
  • Best model is loaded for final test evaluation
  • Early stopping after 20 epochs without improvement
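The selection logic reduces to tracking the best validation AUC seen so far; a condensed sketch with made-up per-epoch AUC values:

```python
# Hypothetical AUC trajectory to illustrate best-model tracking
best_auc, best_state, patience = 0.0, None, 0
for epoch, val_auc in enumerate([0.60, 0.70, 0.65, 0.72, 0.71], start=1):
    if val_auc > best_auc:
        best_auc, best_state, patience = val_auc, f"epoch {epoch}", 0
    else:
        patience += 1  # best model remains unchanged

print(best_state, best_auc)  # epoch 4 0.72
```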

Per-Fold Test Evaluation

After training, evaluate on the fold's test set (main.py:604-638):
# Load best model for this fold
if best_model_state is not None:
    fold_encoder.load_state_dict_flexible(best_model_state['encoder'])
    classifier.load_state_dict(best_model_state['classifier'])

# Test set evaluation
test_results = evaluate_detailed(test_loader, return_probs=True)

print(classification_report(
    test_results['targets'],
    test_results['predictions'],
    target_names=['Stable', 'Converter']
))

Cross-Validation Results

Per-Fold Metrics

Metrics collected for each fold (main.py:73-80, main.py:640-645):
fold_results = {
    'test_acc': [],
    'balanced_acc': [],
    'minority_f1': [],
    'test_auc': [],
    'train_acc': [],
    'balanced_train_acc': []
}

# After each fold
fold_results['test_acc'].append(test_results['accuracy'])
fold_results['balanced_acc'].append(test_results['balanced_accuracy'])
fold_results['minority_f1'].append(test_results['minority_f1'])
fold_results['test_auc'].append(test_results['auc'])

Aggregate Summary

Final cross-validation summary (main.py:647-654):
print("\nCross-Validation Summary:")
for metric, values in fold_results.items():
    if values:
        mean = np.mean(values)
        std = np.std(values)
        print(f"{metric}: {mean:.3f} ± {std:.3f}")
Example Output:
Cross-Validation Summary:
test_acc: 0.742 ± 0.038
balanced_acc: 0.718 ± 0.045
minority_f1: 0.654 ± 0.062
test_auc: 0.801 ± 0.029
train_acc: 0.856 ± 0.021
balanced_train_acc: 0.834 ± 0.028
The standard deviation (±) indicates variability across folds. Lower std suggests more stable model performance.

Conversion Tracking

Per-fold conversion-specific analysis (main.py:630-638):
conversion_results = analyze_conversion_predictions(
    test_subjects,
    test_results['predictions'],
    test_results['targets'],
    label_csv_path
)
print_conversion_accuracy_report(conversion_results)
fold_conversion_results.append(conversion_results)
Aggregated Results (main.py:657-659):
if fold_conversion_results:
    aggregated = aggregate_conversion_results(fold_conversion_results)
    print_conversion_accuracy_report(aggregated)
Analyzes predictions for different conversion trajectories (e.g., CN→MCI, MCI→AD).

Validation Usage

Learning Rate Scheduling

Validation metrics guide learning rate adjustments (main.py:558):
scheduler.step(val_results['balanced_accuracy'])
ReduceLROnPlateau:
  • Monitors validation balanced accuracy
  • Reduces LR by 50% after 10 epochs without improvement
  • Minimum LR: 1e-6
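A pure-Python sketch of that plateau rule (mode "max", factor 0.5, patience 10, min_lr 1e-6). This is an illustration of the idea, not torch's exact implementation of ReduceLROnPlateau.

```python
# Plateau rule: halve the LR once the metric fails to improve for
# more than `patience` consecutive epochs, never going below `min_lr`.
def step_plateau(lr, best, metric, bad_epochs,
                 factor=0.5, patience=10, min_lr=1e-6):
    if metric > best:                  # improvement resets the counter
        return lr, metric, 0
    bad_epochs += 1
    if bad_epochs > patience:          # plateau detected: reduce the LR
        return max(lr * factor, min_lr), best, 0
    return lr, best, bad_epochs

lr, best, bad = 1e-3, 0.0, 0
for metric in [0.70] + [0.69] * 11:    # 11 epochs without improvement
    lr, best, bad = step_plateau(lr, best, metric, bad)
print(lr)  # 0.0005 after the plateau
```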

Early Stopping

Implicit early stopping via patience (main.py:565-587):
if current_auc > best_auc:
    patience_counter = 0
    save_model(...)
else:
    patience_counter += 1
    # Training continues but best model remains unchanged
Patience: 20 epochs (hardcoded, not configurable)
Early stopping does NOT terminate training. Training continues for the full n_epochs, but only the best model (by validation AUC) is saved and evaluated on the test set.

Model Checkpoints

Each fold saves its best model (main.py:582-584):
if opt.save_model:
    torch.save(
        best_model_state,
        os.path.join(opt.save_path, f'best_model_fold{fold}.pth')
    )
Checkpoint Contents:
best_model_state = {
    'encoder': fold_encoder.state_dict().copy(),
    'classifier': classifier.state_dict().copy(),
    'epoch': epoch,
    'val_results': val_results
}
File Structure:
./model/
├── best_model_fold0.pth
├── best_model_fold1.pth
├── best_model_fold2.pth
├── best_model_fold3.pth
└── best_model_fold4.pth

Best Practices

For Small Datasets

# Use more folds for better coverage
python main.py --num_folds 10

For Large Datasets

# Single split for faster iteration
python main.py --num_folds 1

For Final Evaluation

# Standard 5-fold with multiple runs
for seed in 42 43 44 45 46; do
    python main.py --num_folds 5 # (Note: would need to add --seed argument)
done
The current implementation uses a fixed seed (42). To run multiple seeds, you would need to modify main.py to accept a --seed argument.

Troubleshooting

Unbalanced Folds

Symptom: Some folds have very few converters.
Solution: Verify that stratification is working:
# Check class distribution per fold (subject_labels from get_kfold_splits)
from collections import Counter

for fold, split in enumerate(fold_splits):
    counts = Counter(subject_labels[s] for s in split['train_subjects'])
    print(f"Fold {fold}: Stable={counts[0]}, Converter={counts[1]}")

High Variance Across Folds

Symptom: Large standard deviation in the cross-validation summary.
Possible Causes:
  • Small dataset (< 100 subjects)
  • One fold has unusual data distribution
  • Model underfitting or overfitting
Solutions:
  • Increase num_folds for more stable estimates
  • Check per-fold performance to identify outliers
  • Adjust model capacity or regularization

Memory Issues

Symptom: OOM error during batched encoding.
Solution: Reduce the batch size or the maximum number of visits:
python main.py --batch_size 8 --max_visits 5
Reducing batch_size below 4 may lead to unstable training due to small batch statistics in LSTM processing.
