
Quick start

This guide will help you set up your data and run your first STGNN training session.

Prerequisites

Before you begin, ensure you have:
  • Completed the installation steps
  • Access to ADNI fMRI data or similar functional connectivity data
  • At least 16GB RAM (32GB recommended for larger datasets)

Data setup

STGNN requires functional connectivity matrices and patient labels organized in a specific structure.
Step 1: Organize FC matrices

Place your functional connectivity (FC) matrices in the data/FC_Matrices/ directory.

File naming format:
sub-XXXXXX_run-XX_fc_matrix.npz
File structure: Each .npz file should contain a key named fc_matrix with the connectivity matrix.
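import os

import numpy as np

# fmri_time_series: array of shape (n_timepoints, n_regions), here 200 x 100.
# Random data is a placeholder; substitute your preprocessed ROI time series.
fmri_time_series = np.random.default_rng(0).standard_normal((200, 100))

# np.corrcoef treats rows as variables, so transpose to get a 100x100 region-by-region matrix
fc_matrix = np.corrcoef(fmri_time_series.T)

# Save under the key "fc_matrix" with the required file name
os.makedirs('data/FC_Matrices', exist_ok=True)
np.savez('data/FC_Matrices/sub-123456_run-01_fc_matrix.npz',
         fc_matrix=fc_matrix)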
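Before training, you may want to confirm that every file follows the naming and key conventions above. The helper below is a hypothetical sketch, not part of STGNN. It assumes subject IDs are alphanumeric after "sub-" and run numbers are digits, and it also checks that the matrix is square and symmetric, which holds for correlation-based FC matrices.

```python
import os
import re
import tempfile

import numpy as np

# Assumed pattern: "sub-" + alphanumeric ID, "_run-" + digits, e.g. sub-123456_run-01_fc_matrix.npz
FC_NAME = re.compile(r"^(sub-[A-Za-z0-9]+)_run-(\d+)_fc_matrix\.npz$")


def check_fc_file(path):
    """Return (subject, run, n_regions) for a valid FC file; raise ValueError otherwise."""
    name = os.path.basename(path)
    match = FC_NAME.match(name)
    if match is None:
        raise ValueError(f"unexpected file name: {name}")
    with np.load(path) as data:  # .npz files load as a closable NpzFile
        if "fc_matrix" not in data.files:
            raise ValueError(f"{name}: no 'fc_matrix' key (found {data.files})")
        fc = data["fc_matrix"]
    if fc.ndim != 2 or fc.shape[0] != fc.shape[1]:
        raise ValueError(f"{name}: expected a square matrix, got shape {fc.shape}")
    if not np.allclose(fc, fc.T, equal_nan=True):
        raise ValueError(f"{name}: matrix is not symmetric")
    return match.group(1), int(match.group(2)), fc.shape[0]


if __name__ == "__main__":
    # Round-trip a synthetic 100-region matrix through a temporary directory
    ts = np.random.default_rng(0).standard_normal((200, 100))
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "sub-123456_run-01_fc_matrix.npz")
        np.savez(path, fc_matrix=np.corrcoef(ts.T))
        print(check_fc_file(path))  # ('sub-123456', 1, 100)
```

To check a whole directory, call check_fc_file on each entry of os.listdir("data/FC_Matrices") and collect the errors.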
Step 2: Prepare patient labels

Create the temporal labels file data/TADPOLE_TEMPORAL.csv.

Required columns:
  • Subject: Patient identifier (e.g., “sub-123456”)
  • Visit: Visit identifier or code
  • Label_CS_Num: Binary label (0=Stable, 1=Converter)
  • Visit_Order: Sequential visit number (1, 2, 3, …)
  • Months_From_Baseline: Months since first visit
  • Months_To_Next_Original: Months until next visit (-1 if last visit)
Example CSV structure:
Subject,Visit,Label_CS_Num,Visit_Order,Months_From_Baseline,Months_To_Next_Original
sub-123456,bl,0,1,0,6.2
sub-123456,m06,0,2,6.2,6.1
sub-123456,m12,1,3,12.3,-1
sub-789012,bl,0,1,0,12.5
sub-789012,m12,0,2,12.5,-1
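The columns are linked. Within a subject, Visit_Order counts up from 1, the first visit is at 0 months, each visit's Months_From_Baseline plus Months_To_Next_Original should equal the next visit's Months_From_Baseline (6.2 + 6.1 = 12.3 in the example), and the last visit has -1. A hypothetical checker for these invariants (not part of STGNN; the rounding tolerance is an assumption) could look like this:

```python
import csv
import io
import math
from collections import defaultdict

REQUIRED = ["Subject", "Visit", "Label_CS_Num", "Visit_Order",
            "Months_From_Baseline", "Months_To_Next_Original"]


def check_temporal_csv(f, tol=0.15):
    """Return a list of problems in a TADPOLE_TEMPORAL.csv-style file object.

    tol allows for each month value being rounded to one decimal place.
    """
    reader = csv.DictReader(f)
    missing = [c for c in REQUIRED if c not in (reader.fieldnames or [])]
    if missing:
        return [f"missing columns: {missing}"]

    problems = []
    by_subject = defaultdict(list)
    for row in reader:
        if row["Label_CS_Num"].strip() not in ("0", "1"):
            problems.append(f"{row['Subject']} {row['Visit']}: Label_CS_Num must be 0 or 1")
        by_subject[row["Subject"]].append(row)

    for subject, rows in by_subject.items():
        rows.sort(key=lambda r: int(r["Visit_Order"]))
        orders = [int(r["Visit_Order"]) for r in rows]
        if orders != list(range(1, len(rows) + 1)):
            problems.append(f"{subject}: Visit_Order should be 1..{len(rows)}, got {orders}")
        if float(rows[0]["Months_From_Baseline"]) != 0:
            problems.append(f"{subject}: first visit should be 0 months from baseline")
        for cur, nxt in zip(rows, rows[1:]):
            expected = float(cur["Months_From_Baseline"]) + float(cur["Months_To_Next_Original"])
            if not math.isclose(expected, float(nxt["Months_From_Baseline"]), abs_tol=tol):
                problems.append(f"{subject} {cur['Visit']}: gap does not reach next visit")
        if float(rows[-1]["Months_To_Next_Original"]) != -1:
            problems.append(f"{subject}: last visit should have Months_To_Next_Original = -1")
    return problems


EXAMPLE = """\
Subject,Visit,Label_CS_Num,Visit_Order,Months_From_Baseline,Months_To_Next_Original
sub-123456,bl,0,1,0,6.2
sub-123456,m06,0,2,6.2,6.1
sub-123456,m12,1,3,12.3,-1
sub-789012,bl,0,1,0,12.5
sub-789012,m12,0,2,12.5,-1
"""

print(check_temporal_csv(io.StringIO(EXAMPLE)))  # []
```

For the real file, open it with open("data/TADPOLE_TEMPORAL.csv", newline="") and pass the file object in; an empty list means no problems were found.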
If you have TADPOLE_Simplified.csv and TADPOLE_COMPLETE.csv from ADNI, you can generate TADPOLE_TEMPORAL.csv automatically:
python setup_temporal_data.py
Step 3: Create output directory

Create a directory for saving trained models:
mkdir -p model
The output directory is created automatically if it doesn’t exist, but creating it yourself confirms up front that you have write permission to it.

Generate temporal data (optional)

If you have original TADPOLE data files, generate the temporal dataset:
python setup_temporal_data.py
This script:
  • Loads TADPOLE_COMPLETE.csv and TADPOLE_Simplified.csv
  • Calculates temporal gaps between visits
  • Creates sequential visit orders
  • Generates TADPOLE_TEMPORAL.csv with all required columns
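The source of setup_temporal_data.py is not shown in this guide. As a rough sketch of the gap and ordering steps listed above, the hypothetical build_temporal_rows below takes (subject, visit code, months from baseline, label) tuples, which it assumes you have already extracted from the TADPOLE files, and produces rows with the six required columns. It is not the script's actual code.

```python
import csv
import sys
from collections import defaultdict


def build_temporal_rows(visits):
    """Hypothetical helper (not setup_temporal_data.py itself).

    visits: iterable of (subject, visit_code, months_from_baseline, label) tuples.
    Returns dicts with the six TADPOLE_TEMPORAL.csv columns.
    """
    by_subject = defaultdict(list)
    for subject, visit, months, label in visits:
        by_subject[subject].append((float(months), visit, int(label)))

    rows = []
    for subject in sorted(by_subject):
        seq = sorted(by_subject[subject])          # chronological order
        baseline = seq[0][0]
        for i, (months, visit, label) in enumerate(seq):
            is_last = i == len(seq) - 1
            rows.append({
                "Subject": subject,
                "Visit": visit,
                "Label_CS_Num": label,
                "Visit_Order": i + 1,              # sequential visit number
                "Months_From_Baseline": round(months - baseline, 1),
                # temporal gap to the next visit, -1 for the last one
                "Months_To_Next_Original": -1 if is_last else round(seq[i + 1][0] - months, 1),
            })
    return rows


if __name__ == "__main__":
    rows = build_temporal_rows([
        ("sub-123456", "m12", 12.3, 1),
        ("sub-123456", "bl", 0.0, 0),
        ("sub-123456", "m06", 6.2, 0),
    ])
    writer = csv.DictWriter(sys.stdout, fieldnames=list(rows[0]))
    writer.writeheader()
    writer.writerows(rows)
```

Sorting by months rather than by visit code avoids relying on codes like "bl" and "m06" sorting correctly as strings.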

Run your first training

Now you’re ready to train your first model!

Basic training

Run with default settings (GraphSAGE-LSTM with TopK pooling and focal loss):
python main.py
The default configuration uses:
  • GNN: GraphSAGE with 2 layers, 256 hidden dimensions
  • Temporal model: Bidirectional LSTM with 64 hidden dimensions
  • Pooling: TopK pooling (ratio=0.3)
  • Training: 100 epochs, batch size 16, 5-fold cross-validation
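TopK pooling keeps only the highest-scoring nodes (brain regions) of each graph. The NumPy sketch below shows the generic operator in the Graph U-Nets form that PyTorch Geometric's TopKPooling implements: project each node's features onto a learnable vector p, keep the ceil(ratio * N) highest scores, and gate the kept features by tanh of their scores. STGNN's own pooling layer may differ in details. Here p is random rather than learned, and the 100-region, 256-dimensional shapes are assumptions chosen to match the defaults above.

```python
import math

import numpy as np


def topk_pool(x, adj, p, ratio=0.3):
    """Generic TopK graph pooling.

    x:   (N, F) node features (e.g. GraphSAGE embeddings)
    adj: (N, N) adjacency or weighted FC matrix
    p:   (F,)   projection vector (learned in a real model)
    Keeps ceil(ratio * N) nodes and returns (x_pooled, adj_pooled, kept_indices).
    """
    score = np.tanh(x @ p / np.linalg.norm(p))     # one score per node
    k = math.ceil(ratio * x.shape[0])
    keep = np.argsort(-score)[:k]                   # indices of the k highest scores
    x_pooled = x[keep] * score[keep][:, None]       # gate kept features by their score
    adj_pooled = adj[np.ix_(keep, keep)]            # induced subgraph on kept nodes
    return x_pooled, adj_pooled, keep


rng = np.random.default_rng(0)
x = rng.standard_normal((100, 256))                 # 100 regions, 256-dim embeddings
adj = np.corrcoef(rng.standard_normal((100, 200)))  # synthetic 100x100 FC matrix
p = rng.standard_normal(256)
xp, ap, keep = topk_pool(x, adj, p, ratio=0.3)
print(xp.shape, ap.shape)  # (30, 256) (30, 30)
```

Multiplying by the score (rather than only selecting nodes) keeps the projection vector on the gradient path, which is what lets p be trained.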

Training with custom parameters

Customize the architecture and training parameters:
# Use GAT instead of GraphSAGE
python main.py --layer_type GAT --gnn_hidden_dim 512 --gnn_num_layers 3

# Use GCN with custom activation
python main.py --layer_type GCN --gnn_activation relu --gnn_num_layers 2

Time-aware prediction (experimental)

Enable temporal gap features for time-aware prediction:
python main.py --use_time_features --exclude_target_visit
When using --use_time_features, you must also use --exclude_target_visit to prevent data leakage from the target visit.
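This guide does not describe how the time features are constructed. The sketch below shows one plausible arrangement, only to illustrate why excluding the target visit matters: the last visit of a sequence is the prediction target, the model receives the FC graphs of the earlier visits plus each one's gap in months to the target, and the target visit's own graph is withheld. The function name and dictionary keys are hypothetical, not STGNN's API.

```python
def split_history_and_target(visits):
    """Hypothetical helper: turn one subject's visits into (graphs, gaps, label).

    visits: list of dicts sorted by Visit_Order, each with keys
      "graph"  - the visit's FC graph (any object),
      "months" - Months_From_Baseline,
      "label"  - Label_CS_Num.
    The last visit is the prediction target; its graph is never returned as input.
    """
    if len(visits) < 2:
        raise ValueError("need at least one history visit plus a target visit")
    *history, target = visits
    graphs = [v["graph"] for v in history]
    # Time feature: months between each history visit and the target visit
    gaps = [target["months"] - v["months"] for v in history]
    return graphs, gaps, target["label"]


visits = [
    {"graph": "G_bl", "months": 0.0, "label": 0},
    {"graph": "G_m06", "months": 6.2, "label": 0},
    {"graph": "G_m12", "months": 12.3, "label": 1},
]
graphs, gaps, label = split_history_and_target(visits)
print(graphs, [round(g, 1) for g in gaps], label)  # ['G_bl', 'G_m06'] [12.3, 6.1] 1
```

If the target visit's graph were also fed in, the model would see the scan taken at the moment whose label it is asked to predict, which turns forecasting into diagnosis.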

Understanding the output

During training, you’ll see output like this:
Loaded 450 fMRI FC graph samples
Created subject graph mapping with 150 subjects
Using Temporal LSTM Model
  - LSTM Hidden Dim: 64
  - LSTM Layers: 1
  - Bidirectional: True
  - Temporal Batch Size: 16 subjects/batch
  - Graph Pooling: TopK (ratio=0.3)
Training for 100 epochs
Minority class forcing for first 20 epochs

==================================================
STARTING FOLD 1/5
==================================================

Processing 90 temporal sequence batches per epoch

Starting LSTM Training for Fold 1
Epoch 1/100 | Loss: 0.5234 | Train Acc: 0.654 | Val Acc: 0.621 | Val Balanced: 0.589
New best model saved (AUC: 0.742, Balanced Acc: 0.589)
...

Key metrics explained

  • Test Accuracy: Fraction of test samples classified correctly
  • Balanced Accuracy: Average of per-class recall (accounts for class imbalance)
  • Minority F1: F1 score for the converter class (the minority class)
  • AUC: Area under the ROC curve (the model’s ability to separate the two classes across all decision thresholds)
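For reference, the four metrics can be computed from labels and predictions as below. This is a dependency-free sketch, not STGNN's evaluation code; scikit-learn's accuracy_score, balanced_accuracy_score, f1_score and roc_auc_score compute the same quantities. It assumes label 1 is the converter class, y_pred holds hard 0/1 predictions and y_score the predicted converter probability; the example data is made up.

```python
def binary_metrics(y_true, y_pred, y_score):
    """Test accuracy, balanced accuracy, minority (class 1) F1 and ROC AUC.

    y_true, y_pred: 0/1 labels (1 = converter); y_score: predicted P(converter).
    Assumes both classes occur in y_true.
    """
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == 1 and p == 1 for t, p in pairs)
    tn = sum(t == 0 and p == 0 for t, p in pairs)
    fp = sum(t == 0 and p == 1 for t, p in pairs)
    fn = sum(t == 1 and p == 0 for t, p in pairs)

    recall_pos = tp / (tp + fn)   # sensitivity for converters
    recall_neg = tn / (tn + fp)   # specificity for stable subjects
    precision_pos = tp / (tp + fp) if tp + fp else 0.0
    denom = precision_pos + recall_pos
    f1_pos = 2 * precision_pos * recall_pos / denom if denom else 0.0

    # AUC = P(random converter scores higher than random stable subject), ties count 1/2
    pos = [s for s, t in zip(y_score, y_true) if t == 1]
    neg = [s for s, t in zip(y_score, y_true) if t == 0]
    wins = sum(1.0 if sp > sn else 0.5 if sp == sn else 0.0 for sp in pos for sn in neg)

    return {
        "test_acc": (tp + tn) / len(pairs),
        "balanced_acc": (recall_pos + recall_neg) / 2,
        "minority_f1": f1_pos,
        "test_auc": wins / (len(pos) * len(neg)),
    }


y_true = [0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 1, 1, 0]
y_score = [0.10, 0.20, 0.15, 0.30, 0.40, 0.70, 0.80, 0.35]
print({k: round(v, 3) for k, v in binary_metrics(y_true, y_pred, y_score).items()})
# {'test_acc': 0.75, 'balanced_acc': 0.667, 'minority_f1': 0.5, 'test_auc': 0.833}
```

The example shows why balanced accuracy sits below plain accuracy on imbalanced data: the model gets 5 of 6 stable subjects right but only 1 of 2 converters.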

Final cross-validation results

After all folds complete, you’ll see aggregated results:
Cross-Validation Summary:
test_acc: 0.829 ± 0.023
balanced_acc: 0.771 ± 0.031
minority_f1: 0.682 ± 0.045
test_auc: 0.854 ± 0.028
train_acc: 0.891 ± 0.019
balanced_train_acc: 0.843 ± 0.025
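Each summary line has the form mean ± spread across the five folds. Assuming the spread is a standard deviation (the guide does not say whether population or sample), the aggregation can be sketched as follows. The per-fold numbers are made up for illustration and are not STGNN results.

```python
import statistics


def summarize_folds(fold_results):
    """Map each metric to (mean, std) across folds.

    Uses the population standard deviation (statistics.pstdev, same as numpy's
    default np.std); switch to statistics.stdev for the sample version.
    """
    summary = {}
    for name in fold_results[0]:
        values = [fold[name] for fold in fold_results]
        summary[name] = (statistics.mean(values), statistics.pstdev(values))
    return summary


# Illustrative per-fold values only, not actual STGNN results
folds = [{"test_acc": v} for v in (0.80, 0.84, 0.82, 0.86, 0.78)]
for name, (mean, std) in summarize_folds(folds).items():
    print(f"{name}: {mean:.3f} ± {std:.3f}")  # test_acc: 0.820 ± 0.028
```

With only five folds, the population and sample versions differ by a factor of sqrt(5/4), roughly 12%, so it is worth knowing which one a reported spread uses.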

Saved models

Trained models are saved in the model/ directory:
model/
├── best_model_fold0.pth
├── best_model_fold1.pth
├── best_model_fold2.pth
├── best_model_fold3.pth
└── best_model_fold4.pth
Each model file contains:
  • Trained encoder (GNN) state dict
  • Trained classifier (LSTM/GRU/RNN) state dict
  • Best epoch number
  • Validation metrics

Common configuration examples

Best performance (default)

GraphSAGE-LSTM configuration (82.9% test accuracy):
python main.py \
  --layer_type GraphSAGE \
  --gnn_hidden_dim 256 \
  --gnn_num_layers 2 \
  --model_type LSTM \
  --lstm_hidden_dim 64 \
  --lstm_bidirectional True \
  --use_topk_pooling \
  --topk_ratio 0.3 \
  --focal_alpha 0.90 \
  --focal_gamma 3.0
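The --focal_alpha and --focal_gamma flags suggest the standard binary focal loss of Lin et al., FL(p_t) = -α_t (1 - p_t)^γ log(p_t). The sketch below assumes that form, with α applied to the converter class (label 1) and 1 - α to the stable class; STGNN's exact implementation may differ. γ shrinks the loss on examples the model already classifies confidently, and α shifts weight toward the minority class.

```python
import math


def binary_focal_loss(probs, labels, alpha=0.90, gamma=3.0, eps=1e-7):
    """Mean binary focal loss: -alpha_t * (1 - p_t)**gamma * log(p_t).

    probs:  predicted probabilities of class 1 (converter)
    labels: 0/1 ground truth
    alpha weights class 1 and (1 - alpha) weights class 0; gamma down-weights easy examples.
    """
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)         # keep log() finite
        p_t = p if y == 1 else 1.0 - p           # probability assigned to the true class
        alpha_t = alpha if y == 1 else 1.0 - alpha
        total += -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)
    return total / len(probs)


# Same confident mistake (p_t = 0.1) on each class: the converter error costs ~9x more
print(round(binary_focal_loss([0.1], [1]), 3), round(binary_focal_loss([0.9], [0]), 3))
# 1.511 0.168
```

With gamma=0 and alpha=0.5 this reduces to half the ordinary binary cross-entropy, which is a quick sanity check for any implementation.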

Fast training

Quick experimentation with a smaller model:
python main.py \
  --gnn_hidden_dim 128 \
  --lstm_hidden_dim 32 \
  --n_epochs 50 \
  --batch_size 32

Memory-efficient

For systems with limited GPU memory:
python main.py \
  --batch_size 8 \
  --gnn_hidden_dim 128 \
  --lstm_hidden_dim 32 \
  --max_visits 5

Next steps

  • Core concepts: Learn about the architecture and methodology
  • Configuration reference: Explore all available configuration options
  • Model architecture: Understand the GNN and temporal components
  • Advanced usage: Pretraining, transfer learning, and hyperparameter tuning

Troubleshooting

“No module named ‘torch_geometric’”

Ensure PyTorch Geometric is installed:
pip install torch-geometric torch-scatter

“CUDA out of memory”

Reduce memory usage:
python main.py --batch_size 8 --gnn_hidden_dim 128

“No such file or directory: ‘data/TADPOLE_TEMPORAL.csv’”

Ensure your data files are in the correct location:
ls data/
# Should show: FC_Matrices/ TADPOLE_TEMPORAL.csv
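The same check can be scripted as a preflight step before a long run. This is a hypothetical helper, not part of STGNN; it only tests for the directory, at least one file matching the naming pattern from Step 1, and the labels CSV.

```python
import tempfile
from pathlib import Path


def check_data_layout(root="data"):
    """Return a list of problems with the data layout this guide expects (empty if OK)."""
    root = Path(root)
    problems = []
    fc_dir = root / "FC_Matrices"
    if not fc_dir.is_dir():
        problems.append(f"missing directory: {fc_dir}")
    elif not any(fc_dir.glob("sub-*_run-*_fc_matrix.npz")):
        problems.append(f"no sub-*_run-*_fc_matrix.npz files in {fc_dir}")
    labels = root / "TADPOLE_TEMPORAL.csv"
    if not labels.is_file():
        problems.append(f"missing file: {labels}")
    return problems


if __name__ == "__main__":
    # Demo on a scratch directory; call check_data_layout("data") for the real check
    with tempfile.TemporaryDirectory() as tmp:
        print(len(check_data_layout(tmp)))  # 2
        (Path(tmp) / "FC_Matrices").mkdir()
        (Path(tmp) / "FC_Matrices" / "sub-123456_run-01_fc_matrix.npz").touch()
        (Path(tmp) / "TADPOLE_TEMPORAL.csv").touch()
        print(check_data_layout(tmp))  # []
```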

Poor performance on minority class

Adjust the focal loss parameters to focus more on the minority class:
python main.py --focal_alpha 0.95 --focal_gamma 4.0 --minority_focus_epochs 30
