Overview
The STGNN training pipeline combines graph neural networks with recurrent architectures to predict disease progression from longitudinal fMRI functional connectivity data. The system uses a two-stage architecture:
- Graph Encoder: Processes individual fMRI connectivity graphs into embeddings
- Temporal Predictor: Models progression patterns across patient visit sequences
Training Pipeline
Data Preparation
The training process begins by loading the FC_ADNIDataset from preprocessed .npz files containing functional connectivity matrices (main.py:88-116):
- Replace infinite values with zeros
- Create subject-to-graph mappings
- Trim sequences to max_visits (default: 10)
- Group multiple scans per visit by subject ID
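A minimal sketch of these preparation steps, assuming per-subject lists of (visit, matrix) pairs; the helper name and input layout here are hypothetical, not the repo's actual loader:

```python
import numpy as np

def prepare_sequences(fc_by_subject, max_visits=10):
    """Group FC matrices by subject, clean them, and trim sequences.

    fc_by_subject: dict mapping subject ID -> list of (visit, matrix) pairs.
    """
    sequences = {}
    for subject, visits in fc_by_subject.items():
        # Sort scans chronologically (multiple scans grouped per subject).
        visits = sorted(visits, key=lambda v: v[0])
        # Replace non-finite values (inf) with zeros, as in the loading step.
        cleaned = [np.where(np.isfinite(m), m, 0.0) for _, m in visits]
        # Trim each subject's sequence to at most max_visits graphs.
        sequences[subject] = cleaned[:max_visits]
    return sequences

demo = {"sub-01": [(2, np.array([[np.inf, 1.0], [1.0, 0.0]])),
                   (1, np.array([[0.0, 0.5], [0.5, 0.0]]))]}
seqs = prepare_sequences(demo, max_visits=1)
# Only the earliest visit survives trimming, with inf replaced by zero.
```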
Training Architecture
The complete model consists of:
Graph Encoder (GraphNeuralNetwork)
- Input: 100-dimensional node features (brain regions)
- GNN layers with GraphNorm and configurable activation
- TopK pooling or global mean/max pooling
- Output: 512-dimensional graph embeddings (256 × 2 from mean+max pooling)
Temporal Predictor
- Input: Sequence of graph embeddings
- Packed sequence processing for variable-length inputs
- Bidirectional option for LSTM
- Final hidden state classification
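The 512-dimensional embedding comes from concatenating global mean and global max pooling over the final node representations. A numpy sketch of that readout, assuming 256 hidden channels per node as stated above:

```python
import numpy as np

def readout(node_features):
    """Concatenate global mean and global max pooling over a graph's nodes.

    node_features: (num_nodes, hidden_dim) array of final GNN-layer states.
    Returns a (2 * hidden_dim,) graph embedding.
    """
    return np.concatenate([node_features.mean(axis=0),
                           node_features.max(axis=0)])

# 100 brain regions with 256 hidden channels -> 512-dim graph embedding.
h = np.random.default_rng(0).standard_normal((100, 256))
emb = readout(h)
assert emb.shape == (512,)
```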
Training Loop
The training loop (main.py:499-588) implements:
- Epoch iteration for n_epochs (default: 100)
- Batch processing through TemporalDataLoader
- Loss computation using Focal Loss + minority class forcing
- Gradient clipping (max_norm=1.0)
- Learning rate scheduling via ReduceLROnPlateau
- Model checkpointing based on validation AUC
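The loss and clipping steps can be sketched in isolation. Hyperparameter values below (gamma, alpha) are illustrative, not the repo's settings, and the functions are simplified stand-ins for the actual implementations:

```python
import numpy as np

def focal_loss(probs, targets, gamma=2.0, alpha=0.75):
    """Binary focal loss: down-weights easy examples via (1 - p_t)^gamma.

    probs: predicted probability of class 1; targets: 0/1 labels.
    alpha > 0.5 weights the minority (converter) class more heavily.
    """
    p_t = np.where(targets == 1, probs, 1.0 - probs)
    a_t = np.where(targets == 1, alpha, 1.0 - alpha)
    return float(np.mean(-a_t * (1.0 - p_t) ** gamma * np.log(p_t + 1e-8)))

def clip_by_global_norm(grads, max_norm=1.0):
    """Scale all gradients so their global L2 norm is at most max_norm."""
    total = np.sqrt(sum(float((g ** 2).sum()) for g in grads))
    scale = min(1.0, max_norm / (total + 1e-8))
    return [g * scale for g in grads]

loss = focal_loss(np.array([0.9, 0.2]), np.array([1, 0]))
clipped = clip_by_global_norm([np.array([3.0, 4.0])], max_norm=1.0)
```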
Optimizer Configuration
Adam optimizer with component-specific learning rates (main.py:294-297):
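Component-specific learning rates are expressed as parameter groups, one per model component. A minimal sketch of the idea, with plain SGD standing in for Adam; the group names and rates here are assumptions, not the repo's values:

```python
import numpy as np

# One parameter group per model component, each with its own learning rate.
param_groups = [
    {"name": "encoder", "params": [np.ones(3)], "lr": 1e-5},
    {"name": "temporal", "params": [np.ones(3)], "lr": 1e-3},
]

def sgd_step(groups, grads_by_name):
    """Apply an update using each group's own learning rate."""
    for group in groups:
        for p, g in zip(group["params"], grads_by_name[group["name"]]):
            p -= group["lr"] * g

sgd_step(param_groups, {"encoder": [np.ones(3)], "temporal": [np.ones(3)]})
# The encoder moves 100x more slowly than the temporal classifier.
```

The same shape carries over to torch.optim.Adam, which accepts a list of per-group dicts with their own lr values.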
Learning Rate Schedule
ReduceLROnPlateau monitors validation balanced accuracy.
Evaluation Metrics
The evaluate_detailed function (main.py:319-380) computes:
- Accuracy: Overall classification accuracy
- Balanced Accuracy: Mean of per-class recalls
- Minority F1: F1-score for converter class (class 1)
- AUC-ROC: Area under ROC curve
- Per-class Precision/Recall: For both stable and converter classes
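Balanced accuracy and minority-class F1 can be computed directly from predictions. A numpy sketch of those two metrics (the actual function also reports overall accuracy, AUC, and per-class precision/recall):

```python
import numpy as np

def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recalls."""
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))

def minority_f1(y_true, y_pred, minority=1):
    """F1-score for the converter class (class 1)."""
    tp = np.sum((y_pred == minority) & (y_true == minority))
    fp = np.sum((y_pred == minority) & (y_true != minority))
    fn = np.sum((y_pred != minority) & (y_true == minority))
    return float(2 * tp / (2 * tp + fp + fn)) if tp else 0.0

y_true = np.array([0, 0, 0, 1, 1])
y_pred = np.array([0, 0, 1, 1, 0])
# class-0 recall 2/3, class-1 recall 1/2 -> balanced accuracy 7/12
```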
Model Selection
Best model selection criteria (main.py:565-578):
- Primary: Validation AUC (higher is better)
- Requirement: Model must predict both classes (unique_preds > 1)
- Early stopping: 20 epochs patience (hardcoded)
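A sketch of the selection logic described above; the state-dict handling is simplified, and actual checkpointing lives in the training loop:

```python
def update_best(state, val_auc, unique_preds, patience=20):
    """Track the best validation AUC; stop after `patience` stale epochs.

    A model only qualifies as 'best' if it predicts both classes
    (unique_preds > 1), so degenerate single-class models are skipped.
    """
    if unique_preds > 1 and val_auc > state["best_auc"]:
        state["best_auc"] = val_auc
        state["stale"] = 0
        state["save"] = True          # caller checkpoints the model here
    else:
        state["stale"] += 1
        state["save"] = False
    state["stop"] = state["stale"] >= patience
    return state

state = {"best_auc": 0.0, "stale": 0}
state = update_best(state, val_auc=0.71, unique_preds=2)   # new best
state = update_best(state, val_auc=0.90, unique_preds=1)   # degenerate: skipped
```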
Reproducibility
All random seeds are set for reproducibility (main.py:54-62):
- Model initialization
- Optimizer creation
- Each fold split
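A minimal seeding helper covering the standard Python and NumPy generators; the repo additionally seeds its deep-learning framework (e.g. a torch.manual_seed call), which is omitted here:

```python
import random
import numpy as np

def set_seed(seed=42):
    """Seed the random generators used by the pipeline.

    Called before model initialization, optimizer creation, and each
    fold split so that every run is reproducible.
    """
    random.seed(seed)
    np.random.seed(seed)

set_seed(0)
a = np.random.rand(3)
set_seed(0)
b = np.random.rand(3)
# Re-seeding yields identical draws.
```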
Output and Logging
Training outputs include:
- Epoch-level train/validation metrics
- Class prediction distribution monitoring
- Learning rate changes
- Best model checkpoints saved to save_path
- Per-fold test set evaluation with classification reports
- Cross-validation summary with mean ± std
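The cross-validation summary aggregates per-fold test metrics into a mean ± std line; a sketch with hypothetical fold scores:

```python
import numpy as np

# Hypothetical per-fold test AUCs; the real values come from the k folds.
fold_auc = np.array([0.71, 0.68, 0.74, 0.70, 0.69])
summary = f"AUC: {fold_auc.mean():.3f} ± {fold_auc.std():.3f}"
```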
The encoder is kept in eval() mode during training when using pretrained weights, serving only as a feature extractor while the temporal classifier learns progression patterns.
Next Steps
- Configuration - Complete list of command-line arguments
- Model Architectures - GNN and RNN architecture options
- Cross-Validation - Stratified k-fold setup
- Class Imbalance - Focal loss and minority handling