Introduction

The Spatiotemporal Graph Neural Network (STGNN) framework combines graph-based spatial analysis with temporal sequence modeling to predict Alzheimer’s disease progression from longitudinal fMRI brain connectivity data.

Core Approach

The STGNN architecture addresses the challenge of understanding disease progression by modeling both:

Spatial Relationships

Brain region connectivity patterns captured through graph neural networks

Temporal Evolution

How these patterns change over time using recurrent neural networks

Architecture Pipeline

Processing Flow

  1. Graph Construction: Brain regions become nodes, functional connectivity becomes edges
  2. Spatial Encoding: Graph Neural Networks extract connectivity patterns from each scan
  3. Temporal Sequencing: Multiple visits per patient create time-ordered sequences
  4. Temporal Modeling: Recurrent networks learn progression patterns across visits
  5. Classification: Final prediction of disease state or conversion risk
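Step 1 of the flow above can be sketched in plain Python: thresholding a functional-connectivity (correlation) matrix into nodes and edges. The threshold value and the tiny matrix here are illustrative assumptions, not values from the actual pipeline.

```python
# Hypothetical sketch of graph construction: brain regions become
# nodes, and correlations above a threshold become undirected edges.

def build_graph(conn, threshold=0.5):
    """conn: symmetric N x N functional-connectivity matrix."""
    n = len(conn)
    nodes = list(range(n))
    edges = [(i, j) for i in range(n) for j in range(i + 1, n)
             if abs(conn[i][j]) >= threshold]
    return nodes, edges

conn = [
    [1.0, 0.8, 0.1],
    [0.8, 1.0, 0.6],
    [0.1, 0.6, 1.0],
]
nodes, edges = build_graph(conn)
print(nodes)  # [0, 1, 2]
print(edges)  # [(0, 1), (1, 2)]
```

Real pipelines typically keep the correlation value as an edge weight rather than discarding it; the binary threshold here is only the simplest variant.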

Key Features

The system processes multiple brain scans from the same patient over time, learning how connectivity patterns evolve. For patients with multiple visits, it uses all visits except the last as input to predict the final state (preventing data leakage). For single-visit patients, it predicts a configurable time horizon ahead (default 6 months).
Optional temporal gap features can be incorporated into the GNN encoder, allowing the model to explicitly account for varying time intervals between visits. Time gaps are normalized using methods like log transformation to handle the wide range of follow-up durations.
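The log transformation mentioned above can be sketched with `log1p`, which maps a 0-day gap to 0 and compresses long follow-up durations; the exact scaling used by the system is an assumption here.

```python
import math

# Sketch of time-gap normalization via a log transform (log1p),
# so that 0-day and multi-year gaps land on a comparable scale.

def normalize_gaps(gaps_in_days):
    return [math.log1p(g) for g in gaps_in_days]

print(normalize_gaps([0, 180, 365, 1460]))
```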
The system supports multiple graph convolution types (GraphSAGE, GCN, GAT) with configurable depth (2-5 layers), hidden dimensions (default 256), and activation functions (ReLU, ELU, GELU, LeakyReLU).
Three temporal predictor architectures are available: LSTM (default), GRU, and vanilla RNN, all of which support bidirectional processing for richer temporal context.

Data Flow Dimensions

Understanding tensor dimensions through the pipeline:
| Stage | Dimension | Description |
| --- | --- | --- |
| Input Graph | [N, F] | N brain regions, F features per region |
| GNN Output | [B, 512] | Batch size B, 512-D embeddings (256 × 2 from mean+max pooling) |
| Temporal Sequence | [B, T, 512] | T visits per patient |
| RNN Hidden State | [B, H] | Hidden dimension H (default 64, ×2 if bidirectional) |
| Classification Logits | [B, 2] | Binary classification output |
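The jump from 256 hidden dimensions to the 512-D GNN output comes from concatenating mean and max readouts. A pure-Python stand-in for that pooling step (the real encoder operates on tensors, not lists) makes the doubling concrete:

```python
# Sketch of the mean+max readout that produces the 2*H graph embedding
# (256 hidden dims x 2 = 512 in the documented configuration).

def mean_max_pool(node_embeddings):
    """node_embeddings: list of per-node vectors, each of length H.
    Returns one graph-level vector of length 2*H (mean then max)."""
    h = len(node_embeddings[0])
    mean = [sum(v[d] for v in node_embeddings) / len(node_embeddings)
            for d in range(h)]
    mx = [max(v[d] for v in node_embeddings) for d in range(h)]
    return mean + mx

nodes = [[1.0, 2.0], [3.0, 4.0]]   # 2 nodes, H = 2
graph_vec = mean_max_pool(nodes)
print(graph_vec)       # [2.0, 3.0, 3.0, 4.0]
print(len(graph_vec))  # 4 == 2 * H
```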

Clinical Application

The model predicts:
  • Current cognitive state: Normal vs. impaired cognition
  • Future conversion risk: Likelihood of progression to Alzheimer’s disease
  • Temporal patterns: How quickly connectivity patterns are deteriorating
Temporal Gap Handling: The system handles varying follow-up intervals. When exclude_target_visit=True, it learns to predict outcomes at specific future timepoints by incorporating normalized time-to-prediction features.
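The exclude_target_visit=True behaviour described above can be sketched as a hypothetical helper (the function name and tuple layout are illustrative): all visits except the last form the input sequence, and the held-out final visit supplies both the prediction target and the time-to-prediction gap.

```python
# Sketch of holding out the final visit as the prediction target,
# which prevents leakage of the target scan into the input sequence.

def split_visits(visits):
    """visits: list of (scan_id, day) tuples in chronological order."""
    if len(visits) < 2:
        raise ValueError("need at least two visits to hold out a target")
    inputs, target = visits[:-1], visits[-1]
    gap = target[1] - inputs[-1][1]  # days from last input to target
    return inputs, target, gap

visits = [("scan_a", 0), ("scan_b", 180), ("scan_c", 365)]
inputs, target, gap = split_visits(visits)
print(inputs)  # [('scan_a', 0), ('scan_b', 180)]
print(gap)     # 185
```

For single-visit patients, the documented fallback is a fixed horizon (default 6 months) rather than a held-out visit.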

Training Strategy

  1. Encoder Pre-training (Optional): The GNN encoder can be pre-trained on static graph classification before temporal training
  2. Temporal Fine-tuning: The full spatiotemporal model is trained end-to-end, optionally freezing the GNN encoder to preserve learned spatial representations
  3. Cross-Validation: 5-fold stratified cross-validation at the subject level ensures robust evaluation
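The key property of subject-level cross-validation is that every scan from one patient stays in the same fold, so no subject appears in both train and test splits. A minimal sketch (round-robin assignment; the real pipeline additionally stratifies folds by label):

```python
from collections import defaultdict

# Sketch of subject-level folding: folds are built over unique
# subjects, never over individual visits.

def subject_folds(subject_ids, k=5):
    subjects = sorted(set(subject_ids))      # one entry per patient
    folds = defaultdict(list)
    for i, s in enumerate(subjects):
        folds[i % k].append(s)
    return dict(folds)

# s1 has two visits but is assigned to exactly one fold:
folds = subject_folds(["s1", "s1", "s2", "s3", "s4", "s5", "s6"])
print(folds)
```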

Performance Considerations

Memory Efficiency: The system uses batched embedding computation: all graphs in a batch are processed in a single forward pass through the GNN encoder, then reshaped into per-subject sequences for the RNN.
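The flatten-encode-regroup pattern can be sketched with a stand-in encoder (an uppercase function here, in place of the real GNN); what matters is that every graph passes through the encoder exactly once, in one flat batch, before being regrouped by subject.

```python
# Sketch of batched embedding computation: one flat batch through the
# encoder, then reshaped back into per-subject visit sequences.

def embed_and_regroup(graphs_per_subject, encode):
    flat = [g for seq in graphs_per_subject for g in seq]  # one big batch
    embeddings = [encode(g) for g in flat]                 # single "forward pass"
    sequences, i = [], 0
    for seq in graphs_per_subject:
        sequences.append(embeddings[i:i + len(seq)])
        i += len(seq)
    return sequences

subjects = [["g1", "g2"], ["g3", "g4", "g5"]]  # 2 subjects, 2 and 3 visits
seqs = embed_and_regroup(subjects, encode=lambda g: g.upper())
print(seqs)  # [['G1', 'G2'], ['G3', 'G4', 'G5']]
```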

Optimization Techniques

  • Packed Sequences: Variable-length sequences are efficiently packed to avoid wasted computation on padding
  • TopK Pooling: Hierarchical graph coarsening retains the most salient nodes (default 30% minimum retention)
  • Focal Loss: Addresses class imbalance by focusing on hard-to-classify examples (configurable α and γ)
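The focal loss in the last bullet follows the standard form FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t); a scalar sketch shows how it down-weights confident, easy examples relative to hard ones (the alpha and gamma values below are common defaults, not necessarily this system's):

```python
import math

# Binary focal loss on a single predicted probability p for label y.
# Easy examples (p_t near 1) are damped by the (1 - p_t)**gamma factor.

def focal_loss(p, y, alpha=0.25, gamma=2.0):
    p_t = p if y == 1 else 1.0 - p
    a_t = alpha if y == 1 else 1.0 - alpha
    return -a_t * (1.0 - p_t) ** gamma * math.log(p_t)

# A confident correct prediction contributes far less loss than a
# badly misclassified one:
print(focal_loss(0.9, 1))
print(focal_loss(0.1, 1))
```

With gamma = 0 and alpha = 0.5 this reduces (up to a constant factor) to ordinary cross-entropy, which is why both parameters are exposed as configurable.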

Next Steps

System Architecture

Deep dive into the two-stage architecture

Spatiotemporal Modeling

Learn how spatial and temporal features combine

Graph Neural Networks

Explore GNN layer types and configurations

Temporal Prediction

Understand LSTM/GRU/RNN architectures
