
Welcome to STGNN

STGNN is a Spatiotemporal Graph Neural Network framework for predicting Alzheimer’s disease progression by combining Graph Neural Networks for brain connectivity analysis with RNNs for temporal modeling across patient visits.

What is STGNN?

STGNN predicts cognitive stage conversion in Alzheimer’s disease patients using resting-state functional MRI (rs-fMRI) data from the Alzheimer’s Disease Neuroimaging Initiative (ADNI). The framework analyzes how brain connectivity patterns change across patient visits to forecast whether a patient will progress to a more advanced cognitive stage.

Key performance metrics

The best-performing GraphSAGE-LSTM architecture achieves:
  • 82.9% test accuracy
  • 77.1% balanced accuracy
  • 85.4% AUC (Area Under the Curve)

This work is based on the paper “Predicting Alzheimer’s Disease Progression Using rs-fMRI and a History-Aware Graph Neural Network”.
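
Accuracy and balanced accuracy diverge when one class outnumbers the other, which is why both are reported. The sketch below shows how the two are computed from a binary confusion matrix. The counts are hypothetical and chosen only to illustrate the arithmetic; they are not results from the paper.

```python
def accuracy(tp, fn, tn, fp):
    """Fraction of all predictions that are correct."""
    return (tp + tn) / (tp + fn + tn + fp)


def balanced_accuracy(tp, fn, tn, fp):
    """Mean of the per-class recalls (sensitivity and specificity)."""
    sensitivity = tp / (tp + fn)  # recall on converters
    specificity = tn / (tn + fp)  # recall on stable patients
    return (sensitivity + specificity) / 2


# Hypothetical counts (not from the paper): 20 converters, 80 stable patients.
tp, fn, tn, fp = 12, 8, 72, 8
print(f"accuracy: {accuracy(tp, fn, tn, fp):.3f}")
print(f"balanced accuracy: {balanced_accuracy(tp, fn, tn, fp):.3f}")
```

Because balanced accuracy weights each class equally, a model that does well only on the majority class scores lower on it than on plain accuracy.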

Key features

Graph Neural Networks

Configurable GNN architectures (GCN, GAT, GraphSAGE) with TopK pooling for hierarchical brain connectivity analysis
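
As a rough illustration of TopK pooling, the sketch below follows the common formulation (the one used by PyTorch Geometric's `TopKPooling` layer): score each node by projecting its features onto a vector, keep the top fraction of nodes, and gate their features by `tanh(score)`. Whether STGNN uses exactly this layer is an assumption; the function name `topk_pool`, the fixed `weights` vector (learned in a real model), and the toy graph are all hypothetical.

```python
import math


def topk_pool(features, adjacency, weights, ratio=0.5):
    """Keep the highest-scoring nodes and the edges between them."""
    norm = math.sqrt(sum(w * w for w in weights))
    # Score each node by projecting its features onto the weight vector.
    scores = [sum(x * w for x, w in zip(row, weights)) / norm for row in features]
    k = math.ceil(ratio * len(features))
    keep = sorted(range(len(features)), key=lambda i: scores[i], reverse=True)[:k]
    # Gate retained features by tanh(score).
    pooled = [[x * math.tanh(scores[i]) for x in features[i]] for i in keep]
    # Restrict the adjacency matrix to the retained nodes.
    pooled_adj = [[adjacency[i][j] for j in keep] for i in keep]
    return keep, pooled, pooled_adj


# Toy graph: 4 brain regions with 2 features each (hypothetical values).
features = [[0.2, 1.0], [0.9, 0.1], [-0.5, 0.3], [0.6, 0.4]]
adjacency = [[0, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 0], [1, 1, 0, 0]]
keep, pooled, pooled_adj = topk_pool(features, adjacency, weights=[1.0, 0.0])
print(keep, pooled_adj)
```

Stacking a graph convolution and a pooling step like this several times yields progressively coarser views of the connectivity graph, which is what "hierarchical" refers to.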

Temporal modeling

LSTM, GRU, and RNN models for capturing disease progression patterns across multiple patient visits

Focal loss

Focal loss to handle class imbalance between stable and converter patients by down-weighting easy, well-classified examples during training
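
For readers unfamiliar with focal loss, here is a minimal sketch of the standard binary form, `FL = -alpha_t * (1 - p_t)^gamma * log(p_t)`. The defaults `alpha=0.25` and `gamma=2.0` are the values from the original focal loss paper (Lin et al.), not necessarily the ones STGNN uses, and treating label `1` as "converter" is an assumption made for illustration.

```python
import math


def binary_focal_loss(prob, label, alpha=0.25, gamma=2.0):
    """Focal loss for one prediction; prob is the predicted P(label == 1)."""
    # p_t is the probability the model assigned to the true class.
    p_t = prob if label == 1 else 1.0 - prob
    # alpha_t is the class weight for the true class.
    alpha_t = alpha if label == 1 else 1.0 - alpha
    # (1 - p_t)^gamma shrinks the loss of well-classified examples.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


easy = binary_focal_loss(0.95, 1)  # confident and correct
hard = binary_focal_loss(0.30, 1)  # true class given low probability
print(f"easy: {easy:.5f}, hard: {hard:.5f}")
```

With `gamma=0` the expression reduces to class-weighted cross-entropy; raising `gamma` shifts more of the training signal toward hard examples.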

Cross-validation

Built-in 5-fold stratified cross-validation with comprehensive evaluation metrics
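
Stratification means each fold keeps roughly the same class proportions as the full dataset. The sketch below shows the idea in plain Python; `stratified_folds` is a hypothetical helper, not STGNN's API, and the label counts are made up. In practice a library routine such as scikit-learn's `StratifiedKFold` does the same job.

```python
import random
from collections import defaultdict


def stratified_folds(labels, n_folds=5, seed=0):
    """Return n_folds lists of test indices with similar class proportions."""
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)
    rng = random.Random(seed)
    folds = [[] for _ in range(n_folds)]
    for indices in by_class.values():
        rng.shuffle(indices)
        # Deal each class round-robin so every fold gets its share.
        for position, index in enumerate(indices):
            folds[position % n_folds].append(index)
    return folds


labels = [0] * 40 + [1] * 10  # hypothetical: 40 stable, 10 converters
folds = stratified_folds(labels)
for fold in folds:
    converters = sum(labels[i] for i in fold)
    print(f"fold size: {len(fold)}, converters: {converters}")
```

Each fold serves once as the test set while the remaining four are used for training.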

How it works

STGNN processes brain connectivity data through a two-stage architecture:
  1. Spatial analysis: Graph Neural Networks analyze functional connectivity matrices from fMRI scans, capturing brain network patterns at each time point
  2. Temporal modeling: Recurrent neural networks process sequences of brain connectivity embeddings across patient visits to predict future cognitive decline
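
The two stages above can be sketched with an untrained toy model. Mean neighbor aggregation stands in for the GNN, and an elementwise `tanh` recurrence stands in for the LSTM/GRU/RNN. The function names, the scalar weights `w_in` and `w_rec`, and the input values are all hypothetical; this shows the data flow only, not STGNN's actual layers.

```python
import math


def embed_visit(features, adjacency):
    """Spatial stage: average each region with its neighbors, then mean-pool."""
    n = len(features)
    dim = len(features[0])
    aggregated = []
    for i in range(n):
        group = [i] + [j for j in range(n) if adjacency[i][j] and j != i]
        aggregated.append(
            [sum(features[j][d] for j in group) / len(group) for d in range(dim)]
        )
    # Mean over regions gives one embedding vector per visit.
    return [sum(row[d] for row in aggregated) / n for d in range(dim)]


def predict_conversion(visits, w_in=1.0, w_rec=0.5):
    """Temporal stage: fold visit embeddings into a hidden state, then score."""
    if not visits:
        raise ValueError("at least one visit is required")
    hidden = None
    for features, adjacency in visits:
        embedding = embed_visit(features, adjacency)
        if hidden is None:
            hidden = [0.0] * len(embedding)
        # Simple recurrent update mixing the new embedding with the history.
        hidden = [math.tanh(w_in * x + w_rec * h) for x, h in zip(embedding, hidden)]
    logit = sum(hidden) / len(hidden)
    return 1.0 / (1.0 + math.exp(-logit))  # score in (0, 1)


# One hypothetical visit: 3 regions, 2 features each, chain-shaped graph.
visit = ([[0.2, 0.1], [0.4, 0.3], [0.1, 0.5]], [[0, 1, 0], [1, 0, 1], [0, 1, 0]])
score = predict_conversion([visit, visit])
print(f"score: {score:.3f}")
```

The key design point is that the spatial stage runs once per visit and produces a fixed-size vector, so the temporal stage can handle patients with different numbers of visits.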

The framework supports:
  • Multiple GNN architectures (GCN, GAT, GraphSAGE)
  • Flexible temporal models (LSTM, GRU, RNN)
  • TopK pooling for hierarchical graph representation
  • Time-aware prediction with temporal gap features
  • Pretrained encoder support for transfer learning
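
To illustrate what a temporal gap feature might look like, the sketch below computes the time since the previous visit and appends it to each visit's embedding. How STGNN actually encodes the gap (units, normalization, where it enters the model) is not specified here, so `gap_features`, `append_gaps`, and the dates are assumptions for illustration only.

```python
from datetime import date


def gap_features(visit_dates, scale_days=365.25):
    """Time since the previous visit, in years; 0.0 for the first visit."""
    if not visit_dates:
        return []
    gaps = [0.0]
    for previous, current in zip(visit_dates, visit_dates[1:]):
        gaps.append((current - previous).days / scale_days)
    return gaps


def append_gaps(embeddings, visit_dates):
    """Attach each visit's gap to its embedding as one extra input feature."""
    return [emb + [gap] for emb, gap in zip(embeddings, gap_features(visit_dates))]


# Hypothetical visit dates and one-dimensional visit embeddings.
dates = [date(2020, 1, 1), date(2020, 7, 1), date(2021, 7, 1)]
embeddings = [[0.1], [0.2], [0.3]]
for row in append_gaps(embeddings, dates):
    print(row)
```

Feeding the gap to the temporal model lets it distinguish two scans taken six months apart from two taken several years apart.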

Get started

Installation

Install dependencies and set up your environment

Quick start

Run your first training in minutes

Core concepts

Learn about the architecture and methodology

Use cases

  • Clinical research: Predict which patients are at risk of converting from mild cognitive impairment to Alzheimer’s disease
  • Drug trials: Identify optimal patient cohorts for clinical trials based on predicted disease progression
  • Early intervention: Enable targeted early intervention strategies for at-risk patients
  • Biomarker discovery: Understand which brain connectivity patterns are most predictive of disease progression

Acknowledgments

This work was supported by NSF grants #CNS-2349663 and #OAC-2528533. This work used Indiana JetStream2 GPU at Indiana University through allocation NAIRR250048 from the Advanced Cyberinfrastructure Coordination Ecosystem: Services & Support (ACCESS) program.
