Introduction
The Spatiotemporal Graph Neural Network (STGNN) framework combines graph-based spatial analysis with temporal sequence modeling to predict Alzheimer’s disease progression from longitudinal fMRI brain connectivity data.
Core Approach
The STGNN architecture addresses the challenge of understanding disease progression by modeling both:
Spatial Relationships
Brain region connectivity patterns captured through graph neural networks
Temporal Evolution
How these patterns change over time using recurrent neural networks
Architecture Pipeline
Processing Flow
- Graph Construction: Brain regions become nodes, functional connectivity becomes edges
- Spatial Encoding: Graph Neural Networks extract connectivity patterns from each scan
- Temporal Sequencing: Multiple visits per patient create time-ordered sequences
- Temporal Modeling: Recurrent networks learn progression patterns across visits
- Classification: Final prediction of disease state or conversion risk
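The five stages above can be sketched end to end in plain NumPy. This is a minimal stand-in, not the actual implementation: the single mean-aggregation layer substitutes for a real GNN, all shapes (N=90 regions, F=116 features, the 0.8 connectivity threshold) are illustrative, and random weights replace trained ones. Only the tensor shapes are meant to match the pipeline described above.

```python
import numpy as np

rng = np.random.default_rng(0)

# --- 1. Graph construction (illustrative sizes): N regions, F features each ---
N, F = 90, 116
node_features = rng.normal(size=(N, F))
adjacency = (rng.uniform(size=(N, N)) > 0.8).astype(float)  # thresholded connectivity

# --- 2. Spatial encoding: one mean-aggregation layer as a GNN stand-in ---
W = rng.normal(size=(F, 256)) * 0.01
deg = adjacency.sum(axis=1, keepdims=True) + 1.0
h = np.maximum((adjacency @ node_features + node_features) / deg @ W, 0.0)  # ReLU

# mean+max pooling -> 512-D graph embedding (256 x 2, as documented)
graph_embedding = np.concatenate([h.mean(axis=0), h.max(axis=0)])  # shape (512,)

# --- 3-4. Temporal sequencing + modeling: stack per-visit embeddings ---
T = 3                                        # visits for one patient
sequence = np.stack([graph_embedding] * T)   # (T, 512); real code encodes each scan

# --- 5. Classification: linear head over the last timestep ---
W_out = rng.normal(size=(512, 2)) * 0.01
logits = sequence[-1] @ W_out                # (2,) binary logits

print(graph_embedding.shape, sequence.shape, logits.shape)
```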
Key Features
Multi-Visit Temporal Modeling
The system processes multiple brain scans from the same patient over time, learning how connectivity patterns evolve. For patients with multiple visits, it uses all visits except the last as input to predict the final state (preventing data leakage). For single-visit patients, it predicts a configurable time horizon ahead (default 6 months).
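The visit-splitting rule can be expressed as a short helper. This is a hedged sketch: `build_training_pair`, the visit-dictionary schema, and the `HORIZON_MONTHS` constant are hypothetical names, though the 6-month default horizon matches the text.

```python
HORIZON_MONTHS = 6  # default prediction horizon for single-visit patients

def build_training_pair(visits):
    """Split one patient's time-ordered visits into model input and target.

    Multi-visit patients: all visits except the last are input, and the last
    visit's label is the target (holding out the final visit prevents leakage).
    Single-visit patients: the lone visit is input, and the target is the
    patient's state a fixed horizon ahead.
    """
    if len(visits) >= 2:
        return visits[:-1], visits[-1]["label"]
    # Single visit: predict HORIZON_MONTHS ahead of the scan date.
    target_time = visits[0]["months"] + HORIZON_MONTHS
    return visits, ("future_state_at", target_time)

visits = [{"months": 0, "label": "CN"}, {"months": 6, "label": "CN"},
          {"months": 12, "label": "AD"}]
inputs, target = build_training_pair(visits)
print(len(inputs), target)  # 2 AD
```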
Time-Aware Predictions
Optional temporal gap features can be incorporated into the GNN encoder, allowing the model to explicitly account for varying time intervals between visits. Time gaps are normalized using methods like log transformation to handle the wide range of follow-up durations.
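One possible form of the log normalization looks like this. The `normalize_gap` helper and the ~10-year normalization constant are assumptions for illustration; the point is that `log1p` keeps short intervals distinguishable while compressing the long tail of follow-up durations.

```python
import math

def normalize_gap(gap_days, max_gap_days=3650.0):
    """Log-compress a follow-up interval into roughly [0, 1].

    max_gap_days (~10 years here) is an illustrative normalization constant.
    """
    return math.log1p(gap_days) / math.log1p(max_gap_days)

gaps = [30, 180, 365, 1460]
print([round(normalize_gap(g), 3) for g in gaps])
```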
Flexible GNN Architectures
Supports multiple graph convolution types (GraphSAGE, GCN, GAT) with configurable depth (2-5 layers), hidden dimensions (default 256), and activation functions (ReLU, ELU, GELU, LeakyReLU).
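A configuration schema mirroring these options might look like the following sketch. `GNNConfig` and its field names are hypothetical; the defaults and allowed ranges are taken from the text above.

```python
from dataclasses import dataclass

SUPPORTED_CONVS = {"sage", "gcn", "gat"}
SUPPORTED_ACTIVATIONS = {"relu", "elu", "gelu", "leaky_relu"}

@dataclass
class GNNConfig:
    conv_type: str = "sage"   # GraphSAGE / GCN / GAT
    num_layers: int = 3       # allowed range: 2-5
    hidden_dim: int = 256     # documented default
    activation: str = "relu"

    def __post_init__(self):
        # Validate against the documented option sets.
        if self.conv_type not in SUPPORTED_CONVS:
            raise ValueError(f"unknown conv type: {self.conv_type}")
        if not 2 <= self.num_layers <= 5:
            raise ValueError("num_layers must be between 2 and 5")
        if self.activation not in SUPPORTED_ACTIVATIONS:
            raise ValueError(f"unknown activation: {self.activation}")

cfg = GNNConfig(conv_type="gat", num_layers=4)
print(cfg)
```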
Multiple RNN Options
Three temporal predictor architectures available: LSTM (default), GRU, and vanilla RNN, all supporting bidirectional processing for richer temporal context.
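A factory for the three options is a one-liner in PyTorch. The `make_temporal_predictor` helper is a hypothetical name, but it illustrates the dimension effect noted in the table below: with `bidirectional=True`, the output feature dimension doubles to 2×hidden_dim.

```python
import torch
import torch.nn as nn

# Map the three documented options to their PyTorch modules.
RNN_TYPES = {"lstm": nn.LSTM, "gru": nn.GRU, "rnn": nn.RNN}

def make_temporal_predictor(rnn_type="lstm", input_dim=512, hidden_dim=64,
                            bidirectional=True):
    return RNN_TYPES[rnn_type](input_size=input_dim, hidden_size=hidden_dim,
                               batch_first=True, bidirectional=bidirectional)

rnn = make_temporal_predictor("lstm")
x = torch.randn(8, 4, 512)   # [B=8 patients, T=4 visits, 512-D embeddings]
out, _ = rnn(x)
print(out.shape)             # [8, 4, 128]: hidden dim doubles when bidirectional
```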
Data Flow Dimensions
Understanding tensor dimensions through the pipeline:

| Stage | Dimension | Description |
|---|---|---|
| Input Graph | [N, F] | N brain regions, F features per region |
| GNN Output | [B, 512] | Batch size B, 512D embeddings (256×2 from mean+max pooling) |
| Temporal Sequence | [B, T, 512] | T visits per patient |
| RNN Hidden State | [B, H] | Hidden dimension H (default 64, ×2 if bidirectional) |
| Classification Logits | [B, 2] | Binary classification output |
Clinical Application
The model predicts:
- Current cognitive state: Normal vs. impaired cognition
- Future conversion risk: Likelihood of progression to Alzheimer’s disease
- Temporal patterns: How quickly connectivity patterns are deteriorating
Temporal Gap Handling: The system handles varying follow-up intervals. When exclude_target_visit=True, it learns to predict outcomes at specific future timepoints by incorporating normalized time-to-prediction features.
Training Strategy
Encoder Pre-training (Optional)
The GNN encoder can be pre-trained on static graph classification before temporal training.
Temporal Fine-tuning
The full spatiotemporal model is trained end-to-end, optionally freezing the GNN encoder to preserve learned spatial representations.
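Freezing the encoder is the standard `requires_grad` pattern in PyTorch. This sketch uses a stand-in `nn.Sequential` encoder rather than the real GNN; only the unfrozen head parameters would then be passed to the optimizer.

```python
import torch.nn as nn

# Stand-in modules: a (frozen) encoder and a trainable classification head.
encoder = nn.Sequential(nn.Linear(116, 256), nn.ReLU(), nn.Linear(256, 512))
head = nn.Linear(512, 2)

for p in encoder.parameters():
    p.requires_grad = False   # spatial representations stay fixed

# Only parameters that still require gradients reach the optimizer.
trainable = [p for p in list(encoder.parameters()) + list(head.parameters())
             if p.requires_grad]
print(len(trainable))         # head weight + bias only
```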
Performance Considerations
Optimization Techniques
- Packed Sequences: Variable-length sequences are efficiently packed to avoid wasted computation on padding
- TopK Pooling: Hierarchical graph coarsening retains the most salient nodes (default 30% minimum retention)
- Focal Loss: Addresses class imbalance by focusing on hard-to-classify examples (configurable α and γ)
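Of the techniques above, focal loss is compact enough to sketch directly. This is the standard binary formulation (Lin et al., "Focal Loss for Dense Object Detection") in NumPy, not the project's own code; α balances the classes and γ down-weights easy, well-classified examples.

```python
import numpy as np

def focal_loss(probs, targets, alpha=0.25, gamma=2.0):
    """Binary focal loss, per example.

    probs:   predicted probability of the positive class, shape (B,)
    targets: 0/1 labels, shape (B,)
    """
    probs = np.clip(probs, 1e-7, 1 - 1e-7)
    p_t = np.where(targets == 1, probs, 1 - probs)      # prob of the true class
    alpha_t = np.where(targets == 1, alpha, 1 - alpha)
    return -alpha_t * (1 - p_t) ** gamma * np.log(p_t)  # (1-p_t)^gamma focuses on hard cases

probs = np.array([0.9, 0.6, 0.1])
targets = np.array([1, 1, 1])
losses = focal_loss(probs, targets)
print(losses)  # the confident prediction (0.9) incurs far less loss than 0.1
```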
Next Steps
System Architecture
Deep dive into the two-stage architecture
Spatiotemporal Modeling
Learn how spatial and temporal features combine
Graph Neural Networks
Explore GNN layer types and configurations
Temporal Prediction
Understand LSTM/GRU/RNN architectures