Overview
Spatiotemporal modeling captures both where changes occur in the brain (spatial) and when they occur (temporal), providing a comprehensive view of disease progression.

Spatial Feature Extraction
Brain Graphs as Structured Data
Each fMRI scan is represented as a graph where:
- Nodes: Brain regions (typically 100-400 ROIs)
- Edges: Functional connectivity between regions
- Node Features: Statistical properties of each region
GNN Spatial Processing
The Graph Neural Network extracts spatial patterns through message passing.

Local Aggregation
Each node aggregates information from its connected neighbors. This captures local connectivity patterns: which brain regions are communicating.
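The aggregation step can be sketched in plain NumPy. This is an illustrative mean-aggregation update, not the model's actual GNN layer (which is learned):

```python
import numpy as np

def aggregate_neighbors(node_feats, adj):
    """One round of mean-aggregation message passing (illustrative sketch).

    node_feats: (N, F) array of region features
    adj: (N, N) binary adjacency from thresholded functional connectivity
    Each node's new feature mixes its own state with the mean of its neighbors.
    """
    deg = adj.sum(axis=1, keepdims=True).clip(min=1)   # avoid divide-by-zero
    neighbor_mean = adj @ node_feats / deg             # mean over connected regions
    return 0.5 * (node_feats + neighbor_mean)          # combine self + neighborhood

# Tiny 3-region example: region 0 connects to regions 1 and 2.
feats = np.array([[1.0], [3.0], [5.0]])
adj = np.array([[0, 1, 1],
                [1, 0, 0],
                [1, 0, 0]], dtype=float)
out = aggregate_neighbors(feats, adj)   # region 0 blends in mean(3, 5) = 4
```

In the trained model the combination weights are learned parameters rather than a fixed 0.5/0.5 mix.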
Multi-Hop Propagation
With multiple layers (default 2-3), information propagates across the graph:
- Layer 1: Immediate neighbors
- Layer 2: 2-hop neighborhood
- Layer 3: 3-hop neighborhood
Hierarchical Pooling
TopK pooling progressively focuses on important regions. This identifies critical nodes: the regions most relevant for classification.
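The TopK idea can be sketched as follows. Here the node scores are passed in directly; in the trained model they come from a learned projection of the node features:

```python
import numpy as np

def topk_pool(node_feats, scores, ratio=0.5):
    """Illustrative TopK pooling: keep the highest-scoring fraction of nodes.

    Kept features are gated by their sigmoid scores, which is what keeps the
    selection differentiable in the learned version.
    """
    k = max(1, int(round(ratio * len(scores))))
    keep = np.argsort(scores)[::-1][:k]              # indices of top-k regions
    gate = 1.0 / (1.0 + np.exp(-scores[keep]))       # sigmoid gating
    return node_feats[keep] * gate[:, None], keep

feats = np.arange(8, dtype=float).reshape(4, 2)      # 4 regions, 2 features each
scores = np.array([0.1, 2.0, -1.0, 0.5])
pooled, kept = topk_pool(feats, scores, ratio=0.5)   # keeps regions 1 and 3
```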
Spatial Encoding Output: Each brain scan → 512-dimensional vector capturing connectivity fingerprint
Temporal Sequence Construction
Multi-Visit Data Organization
The TemporalDataLoader organizes patient data into sequences:
TemporalDataLoader.py:34-44
Sequence Construction Modes
Two modes are supported:
- Exclude Target (Default): exclude_target_visit=True. Purpose: predict a future state without data leakage. Logic: TemporalDataLoader.py:92-132
- Include All Visits

Temporal Gap Normalization
Time intervals vary widely (1 month to 5+ years). Normalization stabilizes training:
TemporalDataLoader.py:134-138
- Log Normalization:
  - Compresses large time ranges
  - Maps 6 months → 1.95, 12 months → 2.56, 24 months → 3.22
  - Preserves relative ordering
- Min-Max Scaling
- Bucketing
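The example values are consistent with a log(1 + months) transform, which can be checked directly (a sketch; the exact formula lives in TemporalDataLoader.py:134-138):

```python
import math

def normalize_gap(months):
    """Log normalization of a time gap: log(1 + months).
    Compresses multi-year ranges while preserving relative ordering."""
    return math.log1p(months)

# 6 months -> 1.95, 12 months -> 2.56, 24 months -> 3.22
values = [round(normalize_gap(m), 2) for m in (6, 12, 24)]
```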
Sequence Padding and Batching
Variable-length sequences must be padded for batch processing:
TemporalDataLoader.py:49-63
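Assuming a PyTorch implementation, padding and packing can be sketched as follows. Dimensions are illustrative (the real embeddings are 512-D):

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

# Three patients with 3, 2, and 1 visit embeddings (feature dim 4 for brevity).
seqs = [torch.randn(3, 4), torch.randn(2, 4), torch.randn(1, 4)]
lengths = torch.tensor([s.size(0) for s in seqs])

# Right-pad with zeros to the longest sequence: shape (batch=3, T=3, F=4).
padded = pad_sequence(seqs, batch_first=True)

# Pack so the RNN skips padded timesteps entirely.
packed = pack_padded_sequence(padded, lengths, batch_first=True,
                              enforce_sorted=False)
```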
Sequence Lengths: The lengths tensor allows the RNN to use packed sequences, ignoring padding and processing only real timesteps.

Temporal Pattern Learning
RNN Processing of Sequences
The temporal model processes the sequence to learn progression patterns:
TemporalPredictor.py:66-82
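A hedged sketch of this step (layer sizes are illustrative, not the actual TemporalPredictor.py:66-82 implementation; the real model uses 512-D inputs and a 128-D hidden state):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

seq = torch.randn(2, 3, 8)                   # batch of 2 patients, up to 3 visits
lengths = torch.tensor([3, 2])               # patient 2 has only 2 real visits
packed = pack_padded_sequence(seq, lengths, batch_first=True,
                              enforce_sorted=False)

_, (h_n, _) = lstm(packed)                   # final hidden state per patient
summary = h_n[-1]                            # (2, 16) progression summary for the classifier
```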
What the RNN Learns
Trajectory Patterns
The RNN learns whether connectivity is:
- Stable: Consistent patterns across visits
- Gradually declining: Progressive degradation
- Rapidly deteriorating: Fast conversion trajectory
- Fluctuating: Noisy or recovering patterns
Visit-to-Visit Changes
The sequential processing captures:
- Direction of change (improving vs. worsening)
- Rate of change (slope of progression)
- Acceleration (is decline speeding up?)
Long-Term Dependencies
LSTM’s gating mechanisms remember:
- Baseline connectivity from early visits
- Cumulative changes over time
- Critical transition points
Time-Aware Enhancement
Optional temporal gap features explicitly inform the model about the prediction horizon.

GNN-Level Time Integration
model.py:120-132
Design Rationale
Why Small Time Dimension?
Time features use only 32D (vs. 512D for graphs) to prevent temporal information from overwhelming spatial patterns. The brain connectivity patterns should drive predictions, with time serving as context.
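A sketch of this fusion, assuming a small MLP time embedding (names and exact layers are illustrative, not the model.py:120-132 implementation):

```python
import torch
import torch.nn as nn

# A tiny MLP embeds the normalized time gap into 32-D, which is concatenated
# with the 512-D graph embedding. The 512:32 ratio keeps spatial patterns
# dominant, with time serving as context.
time_mlp = nn.Sequential(nn.Linear(1, 32), nn.ReLU())

graph_emb = torch.randn(4, 512)                        # batch of 4 spatial embeddings
gap = torch.tensor([[1.95], [2.56], [3.22], [1.95]])   # log-normalized gaps

fused = torch.cat([graph_emb, time_mlp(gap)], dim=-1)  # (4, 544)
```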
Time Feature Flow
Integration Example
Complete pipeline for a single patient:

Data Collection
Patient has 4 visits at months 0, 6, 12, 24. Labels: Normal, Normal, Normal, MCI (converted at Month 24).
Sequence Construction
With exclude_target_visit=True:
- Input: Visits at months 0, 6, 12 (3 timesteps)
- Target: MCI label from month 24
- Time gap: 12 months (24 - 12)
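The construction above can be sketched as follows (the function name and tuple layout are illustrative, not the actual TemporalDataLoader API):

```python
def build_sequence(visits, exclude_target=True):
    """Exclude-target sequence construction (illustrative sketch).

    visits: list of (month, label) pairs, sorted by month.
    Returns (input_months, target_label, time_gap_months).
    """
    if exclude_target and len(visits) > 1:
        inputs, target = visits[:-1], visits[-1]   # hold out the last visit as target
    else:
        inputs, target = visits, visits[-1]
    gap = target[0] - inputs[-1][0]                # months from last input to target
    return [m for m, _ in inputs], target[1], gap

visits = [(0, "Normal"), (6, "Normal"), (12, "Normal"), (24, "MCI")]
seq, label, gap = build_sequence(visits)
# seq = [0, 6, 12], label = "MCI", gap = 12
```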
Spatial Encoding
Each input visit is encoded by GNN:
- Visit 0 → Embedding 1 [512D]
- Visit 6 → Embedding 2 [512D]
- Visit 12 → Embedding 3 [512D]
Temporal Processing
LSTM processes sequence [3, 512]:
- Learns progression from Normal → Normal → Normal
- Predicts future state 12 months ahead
- Output: Hidden state [128D]
Performance Optimization
Batched Embedding Computation
Instead of encoding each visit separately, all graphs in a sequence are encoded in a single batched forward pass:
TemporalDataLoader.py:194-224
Speedup: Processing 14 graphs in a single batch is ~10x faster than 14 separate forward passes due to GPU parallelization.
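The equivalence of batched and per-visit encoding can be illustrated with a stand-in linear encoder (the real code at TemporalDataLoader.py:194-224 batches graph objects through the GNN, but the principle is the same):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(16, 8)            # stand-in for the GNN encoder

graphs = torch.randn(14, 16)          # 14 visit-level feature vectors

# One batched forward pass runs as a single GPU kernel launch...
batched = encoder(graphs)

# ...and matches 14 separate forward passes, which would launch 14 kernels.
looped = torch.stack([encoder(g) for g in graphs])
```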
Practical Considerations
Missing Visits
Handled automatically: Sequences include only available visits. The RNN processes variable-length sequences naturally via packed sequences.
Irregular Intervals
Normalized: Log transformation handles intervals from 1 month to 5+ years. Optional time features explicitly model intervals.
Single Visits
Prediction horizon: Default 6 months ahead, configurable via --single_visit_horizon. Uses the current state to predict a future one.

Sequence Length Limits
Truncation: Set via --max_visits (default 10). Keeps the most recent visits when patients have excessive follow-up.

Evaluation Strategy
main.py:117-150
Next Steps
GNN Layers
Spatial encoding details
RNN Architectures
Temporal modeling options
Full Architecture
Complete system overview