
Overview

The STGNN training pipeline combines graph neural networks with recurrent architectures to predict disease progression from longitudinal fMRI functional connectivity data. The system uses a two-stage architecture:
  1. Graph Encoder: Processes individual fMRI connectivity graphs into embeddings
  2. Temporal Predictor: Models progression patterns across patient visit sequences

Training Pipeline

Data Preparation

The training process begins by loading the FC_ADNIDataset from preprocessed .npz files containing functional connectivity matrices:
dataset = FC_ADNIDataset(
    root="/path/to/ADNI-Data",
    var_name="fc_matrix"
)
Key preprocessing steps in main.py:88-116:
  • Replace infinite values with zeros
  • Create subject-to-graph mappings
  • Trim sequences to max_visits (default: 10)
  • Group multiple scans per visit by subject ID
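The preprocessing steps above can be sketched roughly as follows. `FC_ADNIDataset`'s internals are not shown in this page, so `preprocess_scans`, the tuple layout, and averaging repeated scans within a visit are all assumptions for illustration:

```python
import numpy as np
from collections import defaultdict

MAX_VISITS = 10  # matches the max_visits default

def preprocess_scans(scans):
    """Hypothetical sketch of the steps in main.py:88-116.

    scans: list of (subject_id, visit_id, fc_matrix) tuples.
    Returns {subject_id: [fc_matrix, ...]} trimmed to MAX_VISITS visits.
    """
    by_subject = defaultdict(dict)
    for subject_id, visit_id, fc in scans:
        # replace infinite (and NaN) values with zeros
        fc = np.nan_to_num(fc, nan=0.0, posinf=0.0, neginf=0.0)
        # group multiple scans per visit under the same subject ID
        by_subject[subject_id].setdefault(visit_id, []).append(fc)

    sequences = {}
    for subject_id, visits in by_subject.items():
        # average repeated scans within a visit, order visits chronologically
        seq = [np.mean(mats, axis=0) for _, mats in sorted(visits.items())]
        sequences[subject_id] = seq[:MAX_VISITS]  # trim to max_visits
    return sequences
```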

Training Architecture

The complete model consists of two components:
Graph Encoder (GraphNeuralNetwork)
  • Input: 100-dimensional node features (brain regions)
  • GNN layers with GraphNorm and configurable activation
  • TopK pooling or global mean/max pooling
  • Output: 512-dimensional graph embeddings (256 × 2 from mean+max pooling)
Temporal Predictor (LSTM/GRU/RNN)
  • Input: Sequence of graph embeddings
  • Packed sequence processing for variable-length inputs
  • Bidirectional option for LSTM
  • Final hidden state classification
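The LSTM path of the temporal predictor can be sketched as below. The class name `TemporalPredictor`, the hidden size, and the constructor arguments are assumptions; the real model also offers GRU/RNN variants:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class TemporalPredictor(nn.Module):
    """Sketch of the bidirectional-LSTM variant (names/dims assumed)."""

    def __init__(self, embed_dim=512, hidden_dim=128, n_classes=2,
                 bidirectional=True):
        super().__init__()
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True,
                            bidirectional=bidirectional)
        out_dim = hidden_dim * (2 if bidirectional else 1)
        self.fc = nn.Linear(out_dim, n_classes)

    def forward(self, embeddings, lengths):
        # pack variable-length visit sequences so padding is ignored
        packed = pack_padded_sequence(embeddings, lengths.cpu(),
                                      batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        if self.lstm.bidirectional:
            # concatenate the final forward and backward hidden states
            h = torch.cat([h_n[-2], h_n[-1]], dim=-1)
        else:
            h = h_n[-1]
        return self.fc(h)  # classify from the final hidden state
```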

Training Loop

The training loop (main.py:499-588) implements:
  1. Epoch iteration for n_epochs (default: 100)
  2. Batch processing through TemporalDataLoader
  3. Loss computation using Focal Loss + minority class forcing
  4. Gradient clipping (max_norm=1.0)
  5. Learning rate scheduling via ReduceLROnPlateau
  6. Model checkpointing based on validation AUC
for epoch in range(1, opt.n_epochs + 1):
    fold_encoder.eval()  # Frozen for feature extraction
    classifier.train()
    
    for batch in train_loader:
        optimizer.zero_grad()
        # graph_seq: the frozen encoder's graph embeddings for this batch
        logits = classifier(graph_seq, None, lengths)
        loss = criterion(logits, labels)
        forcing_loss = minority_class_forcing_loss(logits, labels, epoch)
        total_loss = loss + forcing_loss
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=1.0)
        optimizer.step()
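The Focal Loss term in step 3 is the standard formulation (Lin et al., 2017); the `gamma` and `alpha` values below are assumptions, not necessarily the project's settings, and `minority_class_forcing_loss` is a custom function not reproduced here:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, labels, gamma=2.0, alpha=0.75):
    """Standard focal loss sketch; gamma/alpha values are assumptions."""
    ce = F.cross_entropy(logits, labels, reduction='none')
    p_t = torch.exp(-ce)                     # probability of the true class
    weight = alpha * (1.0 - p_t) ** gamma    # down-weight easy examples
    return (weight * ce).mean()
```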

Optimizer Configuration

Adam optimizer with separate parameter groups for the encoder and classifier (main.py:294-297), both using the same base learning rate opt.lr:
optimizer = torch.optim.Adam([
    {'params': encoder_params, 'lr': opt.lr},
    {'params': classifier_params, 'lr': opt.lr}
], betas=(0.9, 0.999), weight_decay=1e-4)

Learning Rate Schedule

ReduceLROnPlateau monitors validation balanced accuracy:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=10, min_lr=1e-6
)
Reduces LR by 50% after 10 epochs without improvement.
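The plateau behavior can be verified with a toy optimizer; the `Linear` module here is only a stand-in:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=10, min_lr=1e-6)

# a flat validation metric counts as "no improvement" in 'max' mode,
# so after patience (10) bad epochs the LR is halved
for epoch in range(12):
    scheduler.step(0.5)
print(optimizer.param_groups[0]['lr'])  # 0.0005 (halved once)
```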

Evaluation Metrics

The evaluate_detailed function (main.py:319-380) computes:
  • Accuracy: Overall classification accuracy
  • Balanced Accuracy: Mean of per-class recalls
  • Minority F1: F1-score for converter class (class 1)
  • AUC-ROC: Area under ROC curve
  • Per-class Precision/Recall: For both stable and converter classes
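Balanced accuracy, the mean of per-class recalls, can be computed directly. This helper is an illustration of the metric, not the project's `evaluate_detailed`:

```python
import numpy as np

def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recalls over the classes present in y_true."""
    classes = np.unique(y_true)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in classes]
    return float(np.mean(recalls))
```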

Model Selection

Best model selection criteria (main.py:565-578):
  • Primary: Validation AUC (higher is better)
  • Requirement: Model must predict both classes (unique_preds > 1)
  • Early stopping: 20 epochs patience (hardcoded)
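The selection rule can be sketched as a small helper; `ModelSelector` and its return convention are assumptions for illustration, not the actual code in main.py:

```python
class ModelSelector:
    """Sketch of the selection/early-stopping rule in main.py:565-578."""

    def __init__(self, patience=20):
        self.best_auc = float('-inf')
        self.bad_epochs = 0
        self.patience = patience

    def update(self, val_auc, unique_preds):
        """Returns (save_checkpoint, stop_training)."""
        # only accept a model that predicts both classes
        if unique_preds > 1 and val_auc > self.best_auc:
            self.best_auc = val_auc
            self.bad_epochs = 0
            return True, False
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience
```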

Reproducibility

All random seeds are set for reproducibility (main.py:54-62):
def set_random_seeds(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
Seeds are reset before:
  • Model initialization
  • Optimizer creation
  • Each fold split

Output and Logging

Training outputs include:
  • Epoch-level train/validation metrics
  • Class prediction distribution monitoring
  • Learning rate changes
  • Best model checkpoints saved to save_path
  • Per-fold test set evaluation with classification reports
  • Cross-validation summary with mean ± std
The encoder is kept in eval() mode during training when using pretrained weights, serving only as a feature extractor while the temporal classifier learns progression patterns.
Ensure sufficient GPU memory for batch processing. The system creates large batches by concatenating all graphs from multiple subjects before computing embeddings in a single forward pass.
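The memory cost comes from merging all subjects' graphs into one disjoint union before the forward pass. A plain-PyTorch sketch of that batching (the idea behind `torch_geometric.data.Batch`, not the library's actual code):

```python
import torch

def batch_graphs(node_feats, edge_indices):
    """Concatenate many small graphs into one big disconnected graph.

    node_feats: list of [num_nodes_i, feat_dim] tensors
    edge_indices: list of [2, num_edges_i] long tensors
    Returns (x, edge_index, batch_vector).
    """
    xs, eis, batch_vec, offset = [], [], [], 0
    for b, (x, ei) in enumerate(zip(node_feats, edge_indices)):
        xs.append(x)
        eis.append(ei + offset)  # shift node ids into the merged graph
        # batch vector maps each node back to its source graph
        batch_vec.append(torch.full((x.size(0),), b, dtype=torch.long))
        offset += x.size(0)
    return torch.cat(xs), torch.cat(eis, dim=1), torch.cat(batch_vec)
```

Memory grows with the total node count across all graphs in the batch, which is why many subjects with long visit sequences can exhaust GPU memory in a single forward pass.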
