Overview
The STGNN training uses stratified k-fold cross-validation to ensure robust evaluation and prevent overfitting. The default configuration performs 5-fold cross-validation with stratified splitting at the subject level.
Stratification Strategy
Subject-Level Splitting
Cross-validation splits are performed at the subject level, not the visit level (main.py:117-169):
def get_kfold_splits(dataset, num_folds=5, seed=42):
    # Extract one label per subject
    subject_labels = {}
    for subj_id, graphs in dataset.subject_graph_dict.items():
        if graphs and hasattr(graphs[0], 'y'):
            subject_labels[subj_id] = graphs[0].y.item()

    subjects = list(subject_labels.keys())
    labels = [subject_labels[s] for s in subjects]

    # Stratified k-fold on subjects
    skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=seed)
Why Subject-Level?
- Prevents data leakage (visits from same subject in train and test)
- Reflects real-world deployment (predicting for new patients)
- More conservative performance estimates
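A minimal sketch of the leakage argument: if the split assigns whole subjects, the visit sets of train and test are necessarily disjoint. The subject IDs and visit names below are toy data, not the project's actual dataset.

```python
# Toy subject -> visits mapping (assumption, for illustration only)
subject_visits = {
    "s1": ["s1_run1", "s1_run2"],
    "s2": ["s2_run1"],
    "s3": ["s3_run1", "s3_run2", "s3_run3"],
    "s4": ["s4_run1"],
}

# Subject-level split: every visit of a subject follows that subject
train_subjects = {"s1", "s3"}
test_subjects = {"s2", "s4"}

train_visits = {v for s in train_subjects for v in subject_visits[s]}
test_visits = {v for s in test_subjects for v in subject_visits[s]}

# Because the split happened at the subject level, no visit can appear
# in both sets
leaked = train_visits & test_visits
```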
Stratification
Stratification ensures balanced class distribution across folds:
for fold_idx, (train_val_idx, test_idx) in enumerate(skf.split(subjects, labels)):
    train_val_subjects = [subjects[i] for i in train_val_idx]
    test_subjects = [subjects[i] for i in test_idx]
Class Balance:
- Each fold maintains approximately the same ratio of stable vs. converter subjects
- Prevents folds with all/no converters
- Critical for imbalanced datasets
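The balance property can be checked directly with scikit-learn's StratifiedKFold: each test fold receives (as near as possible) the same class proportions as the full label set. The labels below are synthetic (30% converters), an assumption for illustration.

```python
# Sketch: StratifiedKFold keeps the converter ratio constant per fold.
from sklearn.model_selection import StratifiedKFold

labels = [1] * 30 + [0] * 70            # 30 converters, 70 stable (synthetic)
subjects = [f"subj_{i}" for i in range(len(labels))]

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_ratios = []
for _, test_idx in skf.split(subjects, labels):
    test_labels = [labels[i] for i in test_idx]
    fold_ratios.append(sum(test_labels) / len(test_labels))
# Every fold ends up with 6 of 20 converters: a 0.3 ratio, matching the whole set
```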
Split Configuration
Three-Way Split
Each fold is divided into train, validation, and test sets (main.py:134-139):
# 80% train+val, 20% test (subject indices come from the stratified
# k-fold loop above; skf.split yields index arrays per fold)

# Further split train+val into 80% train, 20% val
train_idx, val_idx = train_test_split(
    np.arange(len(train_val_subjects)),
    test_size=0.2,
    random_state=seed + fold_idx,
    stratify=train_val_labels
)
Final Ratios (for 5-fold CV):
- Train: 64% of total subjects
- Validation: 16% of total subjects
- Test: 20% of total subjects
Example (100 subjects):
- Fold 1: 64 train, 16 val, 20 test
- Fold 2: 64 train, 16 val, 20 test
- …
- Fold 5: 64 train, 16 val, 20 test
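The ratios above follow from the nested 80/20 splits. A quick arithmetic check, assuming exactly 100 subjects as in the example:

```python
# 20% of all subjects go to test; 20% of the remaining 80% go to validation
total = 100
test = int(total * 0.20)        # 20 subjects
train_val = total - test        # 80 subjects
val = int(train_val * 0.20)     # 16 subjects (0.8 * 0.2 = 16% of total)
train = train_val - val         # 64 subjects (0.8 * 0.8 = 64% of total)
```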
Visit Index Mapping
After subject-level splits, map subjects back to visit indices (main.py:144-154):
def get_visit_indices(subj_list):
    indices = []
    subj_set = set(subj_list)
    for i, data in enumerate(dataset):
        sid = getattr(data, 'subj_id', None)
        if sid is None:
            continue
        base_subject_id = sid.split('_run')[0] if '_run' in sid else sid
        if base_subject_id in subj_set:
            indices.append(i)
    return indices

tr_index = get_visit_indices(train_subjects)
val_index = get_visit_indices(val_subjects)
te_index = get_visit_indices(test_subjects)
Result: Each split contains ALL visits for subjects in that split.
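A toy walk-through of the mapping: a four-item dataset stands in for the real one (an assumption for illustration), and selecting subject "s1" picks up both of its visits, including run-suffixed IDs.

```python
# Toy dataset: visit objects carrying a subj_id attribute (assumption)
from types import SimpleNamespace

dataset = [
    SimpleNamespace(subj_id="s1_run1"),
    SimpleNamespace(subj_id="s2_run1"),
    SimpleNamespace(subj_id="s1_run2"),
    SimpleNamespace(subj_id="s3"),
]

def get_visit_indices(subj_list):
    indices = []
    subj_set = set(subj_list)
    for i, data in enumerate(dataset):
        sid = getattr(data, "subj_id", None)
        if sid is None:
            continue
        # Strip the run suffix so "s1_run2" maps back to subject "s1"
        base = sid.split("_run")[0] if "_run" in sid else sid
        if base in subj_set:
            indices.append(i)
    return indices

tr_index = get_visit_indices(["s1"])   # both visits of s1, nothing else
```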
Fold Configuration
Number of Folds
Configurable via --num_folds argument:
# 5-fold cross-validation (default)
python main.py --num_folds 5
# 10-fold cross-validation (more conservative)
python main.py --num_folds 10
# Single train/val/test split
python main.py --num_folds 1
--num_folds 1 performs a single 64/16/20 train/val/test split. Useful for faster experimentation or when you have abundant data.
Fold Independence
Each fold trains a completely independent model (main.py:201-228):
for fold, split in enumerate(fold_splits):
    print(f"STARTING FOLD {fold + 1}/{opt.num_folds}")

    # Create a fresh copy of the encoder for this fold
    set_random_seeds(42)
    fold_encoder = copy.deepcopy(encoder).to(device)

    # Fresh optimizer and scheduler
    optimizer = torch.optim.Adam([...], lr=opt.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(...)

    # Fresh criterion
    criterion = FocalLoss(alpha=opt.focal_alpha, gamma=opt.focal_gamma)
No Information Leakage:
- Models do not share weights between folds
- Each fold sees a different test set
- Training restarts from scratch (or pretrained weights)
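The key mechanism is copy.deepcopy: each fold gets an independent copy of the encoder's weights, so training one fold cannot affect another. A minimal sketch, where a plain dict of numbers stands in for the real model (an assumption):

```python
# Minimal illustration of per-fold weight independence via deepcopy
import copy

encoder = {"w": [0.1, 0.2]}                       # stands in for pretrained weights
fold_models = [copy.deepcopy(encoder) for _ in range(5)]

fold_models[0]["w"][0] = 99.0                     # "training" fold 0 mutates its copy

# Neither the original weights nor the other folds are affected
```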
Reproducibility
Deterministic Splitting
Random seeds ensure identical splits across runs:
# Global seed for cross-validation splits
set_random_seeds(42)
# StratifiedKFold uses fixed seed
skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)
# Validation split uses fold-dependent seed
train_test_split(..., random_state=seed + fold_idx, ...)
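That a fixed random_state yields identical folds on every run can be verified directly with scikit-learn; the tiny label set below is synthetic (an assumption).

```python
# Sketch: same seed -> same fold assignment, run after run
from sklearn.model_selection import StratifiedKFold

labels = [0, 1] * 10                    # synthetic balanced labels
X = list(range(20))

def fold_test_sets(seed):
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    return [test.tolist() for _, test in skf.split(X, labels)]

runs_match = fold_test_sets(42) == fold_test_sets(42)
```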
Data Loader Seeding
TemporalDataLoader uses seeded shuffling (TemporalDataLoader.py:28-32):
if self.shuffle:
    rng = np.random.RandomState(self.seed)
    rng.shuffle(self.subjects)
Result: Subject order is shuffled but reproducible within each fold.
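The same pattern in isolation: a seeded RandomState permutes the subject list identically on every call, so shuffling is reproducible without being disabled. Subject names here are toy data.

```python
# Sketch of seeded shuffling: same seed, same order
import numpy as np

def shuffled(subjects, seed):
    out = list(subjects)
    rng = np.random.RandomState(seed)
    rng.shuffle(out)                    # in-place, deterministic for a fixed seed
    return out

subjects = ["s1", "s2", "s3", "s4", "s5"]
same = shuffled(subjects, 42) == shuffled(subjects, 42)
```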
Training Per Fold
Independent Training
Each fold runs a complete training loop (main.py:499-588):
for epoch in range(1, opt.n_epochs + 1):
    # Training
    for batch in train_loader:
        # Forward, backward, optimize
        ...

    # Validation
    val_results = evaluate_detailed(val_loader)

    # Model selection based on validation AUC
    if val_results['auc'] > best_auc:
        best_auc = val_results['auc']
        save_model(...)
Model Selection:
- Each fold saves its own best model based on validation AUC
- The best model is loaded for the final test evaluation
- A patience counter tracks 20 epochs without improvement (see Early Stopping below); training itself still runs for the full n_epochs
Per-Fold Test Evaluation
After training, the model is evaluated on the fold's test set (main.py:604-638):
# Load best model for this fold
if best_model_state is not None:
    fold_encoder.load_state_dict_flexible(best_model_state['encoder'])
    classifier.load_state_dict(best_model_state['classifier'])

# Test set evaluation
test_results = evaluate_detailed(test_loader, return_probs=True)
print(classification_report(
    test_results['targets'],
    test_results['predictions'],
    target_names=['Stable', 'Converter']
))
Cross-Validation Results
Per-Fold Metrics
Metrics collected for each fold (main.py:73-80, main.py:640-645):
fold_results = {
    'test_acc': [],
    'balanced_acc': [],
    'minority_f1': [],
    'test_auc': [],
    'train_acc': [],
    'balanced_train_acc': []
}

# After each fold
fold_results['test_acc'].append(test_results['accuracy'])
fold_results['balanced_acc'].append(test_results['balanced_accuracy'])
fold_results['minority_f1'].append(test_results['minority_f1'])
fold_results['test_auc'].append(test_results['auc'])
Aggregate Summary
Final cross-validation summary (main.py:647-654):
print("\nCross-Validation Summary:")
for metric, values in fold_results.items():
    if values:
        mean = np.mean(values)
        std = np.std(values)
        print(f"{metric}: {mean:.3f} ± {std:.3f}")
Example Output:
Cross-Validation Summary:
test_acc: 0.742 ± 0.038
balanced_acc: 0.718 ± 0.045
minority_f1: 0.654 ± 0.062
test_auc: 0.801 ± 0.029
train_acc: 0.856 ± 0.021
balanced_train_acc: 0.834 ± 0.028
The standard deviation (±) indicates variability across folds. Lower std suggests more stable model performance.
Conversion Tracking
Per-fold conversion-specific analysis (main.py:630-638):
conversion_results = analyze_conversion_predictions(
    test_subjects,
    test_results['predictions'],
    test_results['targets'],
    label_csv_path
)
print_conversion_accuracy_report(conversion_results)
fold_conversion_results.append(conversion_results)
Aggregated Results (main.py:657-659):
if fold_conversion_results:
    aggregated = aggregate_conversion_results(fold_conversion_results)
    print_conversion_accuracy_report(aggregated)
Analyzes predictions for different conversion trajectories (e.g., CN→MCI, MCI→AD).
Validation Usage
Learning Rate Scheduling
Validation metrics guide learning rate adjustments (main.py:558):
scheduler.step(val_results['balanced_accuracy'])
ReduceLROnPlateau:
- Monitors validation balanced accuracy
- Reduces LR by 50% after 10 epochs without improvement
- Minimum LR: 1e-6
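The plateau rule described above (halve the LR after 10 epochs without improvement, floor at 1e-6) can be written out in plain Python; the real code uses torch's ReduceLROnPlateau, so this is only a behavioural sketch under those stated parameters.

```python
# Pure-Python sketch of the plateau schedule: factor 0.5, patience 10, min_lr 1e-6
def plateau_lr(metric_history, lr=1e-3, factor=0.5, patience=10, min_lr=1e-6):
    best, wait = float("-inf"), 0
    for m in metric_history:
        if m > best:
            best, wait = m, 0          # improvement resets the patience counter
        else:
            wait += 1
            if wait > patience:        # too long without improvement: reduce LR
                lr, wait = max(lr * factor, min_lr), 0
    return lr

# One improvement, then 11 flat epochs -> a single halving of the learning rate
final_lr = plateau_lr([0.70] + [0.65] * 11)
```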
Early Stopping
Implicit early stopping via patience (main.py:565-587):
if current_auc > best_auc:
    patience_counter = 0
    save_model(...)
else:
    patience_counter += 1
    # Training continues, but the best model remains unchanged
Patience: 20 epochs (hardcoded, not configurable)
Early stopping does NOT terminate training. Training continues for full n_epochs, but only the best model (by validation AUC) is saved and evaluated on test set.
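The "train to the end, keep the best" behaviour reduces to tracking the argmax of validation AUC over epochs. A minimal sketch with a synthetic AUC trace (an assumption):

```python
# Training runs over every epoch; only the best-AUC epoch's model is kept
val_aucs = [0.60, 0.72, 0.70, 0.75, 0.74, 0.73]   # synthetic per-epoch validation AUC

best_auc, best_epoch = float("-inf"), None
for epoch, auc in enumerate(val_aucs, start=1):
    if auc > best_auc:
        best_auc, best_epoch = auc, epoch          # checkpoint would be saved here
# Later, non-improving epochs change nothing: the epoch-4 model is evaluated on test
```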
Model Checkpoints
Each fold saves its best model (main.py:582-584):
if opt.save_model:
    torch.save(
        best_model_state,
        os.path.join(opt.save_path, f'best_model_fold{fold}.pth')
    )
Checkpoint Contents:
best_model_state = {
    'encoder': fold_encoder.state_dict().copy(),
    'classifier': classifier.state_dict().copy(),
    'epoch': epoch,
    'val_results': val_results
}
File Structure:
./model/
├── best_model_fold0.pth
├── best_model_fold1.pth
├── best_model_fold2.pth
├── best_model_fold3.pth
└── best_model_fold4.pth
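The per-fold naming convention above can be reproduced directly; the save directory here mirrors the listing but is otherwise an assumption.

```python
# Enumerate the expected checkpoint paths for a 5-fold run
import os

save_path = "./model"
paths = [os.path.join(save_path, f"best_model_fold{fold}.pth") for fold in range(5)]
```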
Best Practices
For Small Datasets
# Use more folds for better coverage
python main.py --num_folds 10
For Large Datasets
# Single split for faster iteration
python main.py --num_folds 1
For Final Evaluation
# Standard 5-fold with multiple runs
for seed in 42 43 44 45 46; do
    python main.py --num_folds 5  # (note: would need a --seed argument added)
done
The current implementation uses a fixed seed (42). To run multiple seeds, you would need to modify main.py to accept a --seed argument.
Troubleshooting
Unbalanced Folds
Symptom: Some folds have very few converters
Solution: Ensure stratification is working:
# Check class distribution per fold
# Check class distribution per fold (count_classes is an illustrative helper,
# not part of main.py)
for fold, split in enumerate(fold_splits):
    stable, converter = count_classes(split['train_subjects'])
    print(f"Fold {fold}: Stable={stable}, Converter={converter}")
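A concrete, self-contained version of that check using collections.Counter; the fold_splits contents here are toy labels (an assumption), with 0 = stable and 1 = converter.

```python
# Count stable vs. converter subjects per fold with a Counter
from collections import Counter

fold_splits = [
    {"train_labels": [0, 0, 1, 0, 1]},   # toy fold 0
    {"train_labels": [0, 1, 1, 1, 0]},   # toy fold 1
]

balance = []
for fold, split in enumerate(fold_splits):
    counts = Counter(split["train_labels"])
    balance.append((counts[0], counts[1]))   # (stable, converter)
    print(f"Fold {fold}: Stable={counts[0]}, Converter={counts[1]}")
```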
High Variance Across Folds
Symptom: Large standard deviation in cross-validation summary
Possible Causes:
- Small dataset (< 100 subjects)
- One fold has unusual data distribution
- Model underfitting or overfitting
Solutions:
- Increase num_folds for more stable estimates
- Check per-fold performance to identify outliers
- Adjust model capacity or regularization
Memory Issues
Symptom: OOM error during batched encoding
Solution: Reduce batch size or max visits:
python main.py --batch_size 8 --max_visits 5
Reducing batch_size below 4 may lead to unstable training due to small batch statistics in LSTM processing.