Overview
Temporal prediction models process sequences of brain graph embeddings to learn disease progression patterns. The STGNN framework supports three Recurrent Neural Network (RNN) architectures:
- LSTM (default choice): Long Short-Term Memory - handles long-term dependencies with gating mechanisms
- GRU: Gated Recurrent Unit - simpler than LSTM, faster training
- RNN: Vanilla RNN - baseline architecture for comparison
LSTM: Long Short-Term Memory (Default)
Architecture
File: TemporalPredictor.py:5-86
TemporalPredictor.py:6-29
How LSTM Works
Forget Gate
Decides what information to discard from cell state:
- Value range: [0, 1] via sigmoid
- Meaning: 1 = keep everything, 0 = forget everything
- Clinical interpretation: Forget noisy baseline scans, retain stable patterns
Input Gate
Decides what new information to add to cell state:
- Update: c_t = f_t * c_{t-1} + i_t * c~_t (forget-gated old state plus input-gated candidate)
- Clinical interpretation: Incorporate new connectivity patterns indicating progression
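The gate arithmetic can be illustrated with a scalar toy cell (hand-picked weights, purely illustrative; real models use `torch.nn.LSTM`):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def lstm_cell_step(x, h_prev, c_prev, w):
    """One LSTM step for scalar states, following the standard formulation:
      f_t = sigmoid(...)   forget gate: 1 = keep old cell state
      i_t = sigmoid(...)   input gate: how much new candidate to add
      c~_t = tanh(...)     candidate cell state
      c_t = f_t * c_prev + i_t * c~_t
      o_t = sigmoid(...)   output gate
      h_t = o_t * tanh(c_t)
    """
    f = sigmoid(w["wf"] * x + w["uf"] * h_prev + w["bf"])
    i = sigmoid(w["wi"] * x + w["ui"] * h_prev + w["bi"])
    c_tilde = math.tanh(w["wc"] * x + w["uc"] * h_prev + w["bc"])
    c = f * c_prev + i * c_tilde
    o = sigmoid(w["wo"] * x + w["uo"] * h_prev + w["bo"])
    h = o * math.tanh(c)
    return h, c

# With a saturated forget gate (f ~ 1) and a closed input gate (i ~ 0),
# the cell state is carried through almost unchanged - the "keep the
# stable baseline, ignore the noisy input" behavior described above.
w = {"wf": 0.0, "uf": 0.0, "bf": 10.0,   # f ~ 1: keep everything
     "wi": 0.0, "ui": 0.0, "bi": -10.0,  # i ~ 0: add nothing
     "wc": 1.0, "uc": 0.0, "bc": 0.0,
     "wo": 0.0, "uo": 0.0, "bo": 10.0}
h, c = lstm_cell_step(x=5.0, h_prev=0.0, c_prev=0.7, w=w)
# c stays ~0.7 even though the input x = 5.0 is large
```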
Forward Pass Implementation
TemporalPredictor.py:40-86
Packed Sequences Optimization
Efficiency: Packed sequences skip computation on padding tokens, crucial when sequences vary from 1 to 10+ visits.
TemporalPredictor.py:68-70
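A self-contained sketch of this pattern (shapes and sizes are illustrative, not the project's actual configuration):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Batch of 3 padded embedding sequences: [B=3, T_max=4, D=8],
# with true visit counts 4, 2, and 1.
x = torch.randn(3, 4, 8)
lengths = torch.tensor([4, 2, 1])

lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

# pack_padded_sequence drops the padded timesteps so the LSTM never
# computes on them; enforce_sorted=False sorts/unsorts internally.
packed = pack_padded_sequence(x, lengths, batch_first=True,
                              enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# out: [3, 4, 16] re-padded outputs; h_n[-1] holds the last *valid*
# hidden state of each sequence, not the state at a padding position.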
Bidirectional LSTM
When bidirectional=True, the LSTM processes sequences in both directions:
Benefits:
- Forward pass: Early visits → Late visits (progression trajectory)
- Backward pass: Late visits → Early visits (retrospective context)
- Combined: Richer representation of entire sequence
--lstm_bidirectional True
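Note that bidirectionality doubles the feature dimension, so the downstream classifier must expect 2 × hidden_dim inputs. A standalone shape check (sizes are illustrative):

```python
import torch

x = torch.randn(3, 5, 8)  # [B, T, D]
lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True,
                     bidirectional=True)
out, (h_n, c_n) = lstm(x)

# Feature dimension doubles: forward and backward hidden states
# are concatenated at every timestep -> [B, T, 2 * hidden]
# out.shape == torch.Size([3, 5, 32])
# h_n stacks the two directions: [num_layers * 2, B, hidden]
# h_n.shape == torch.Size([2, 3, 16])
```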
Classifier MLP
TemporalPredictor.py:31-38
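The classifier head maps the final hidden state to class logits. A hypothetical sketch of such an MLP (layer sizes and dropout rate are illustrative; the actual definition is at TemporalPredictor.py:31-38):

```python
import torch

hidden_dim, num_classes = 128, 2

# Small MLP head: hidden state -> bottleneck -> class logits,
# with dropout for regularization (the only dropout in play when
# the recurrent stack has a single layer - see Layer Depth below).
classifier = torch.nn.Sequential(
    torch.nn.Linear(hidden_dim, hidden_dim // 2),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.3),
    torch.nn.Linear(hidden_dim // 2, num_classes),
)

logits = classifier(torch.randn(4, hidden_dim))  # [B, num_classes]
```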
Advantages of LSTM
Long-Term Memory
Cell state mechanism allows information to flow unchanged across many timesteps, preventing vanishing gradients. Critical for patients with 5+ visits spanning years.
Selective Remembering
Gating mechanisms learn what to remember (stable baseline) and what to forget (measurement noise). Important for noisy fMRI data.
Gradient Stability
Gates prevent gradient explosion/vanishing during backpropagation through time, enabling reliable training.
Clinical Relevance
Can model both slow progression (retained in cell state) and sudden changes (captured by input gate), matching real disease dynamics.
GRU: Gated Recurrent Unit
Architecture
File: GRUPredictor.py:5-59
GRUPredictor.py:6-28
How GRU Works
Update Gate
Decides how much of the past to keep:
- Combines: LSTM’s forget and input gates into one
- Interpretation: 1 = keep old state, 0 = replace with new
Reset Gate
Decides how much of the previous hidden state to use when forming the new candidate state.
Forward Pass
GRUPredictor.py:39-59
GRU vs. LSTM
Advantages:
- Fewer parameters (2 gates vs. 3 gates + cell state)
- Faster training and inference (~25-30% speedup)
- Lower memory usage
- Often matches LSTM accuracy on many tasks
- Sometimes generalizes better on smaller datasets
- Easier to understand and debug
- Fewer hyperparameters to tune
Disadvantages:
- No separate cell state, so long-term memory is weaker than LSTM's (see the comparison table below)
When to Use GRU
Best for:
- Shorter sequences (< 7 visits)
- Limited computational resources
- Faster experimentation
- When LSTM shows overfitting
RNN: Vanilla Recurrent Neural Network
Architecture
File: RNNPredictor.py:5-60
RNNPredictor.py:6-29
How Vanilla RNN Works
Simplest recurrent architecture:
- No gating mechanisms: All information flows equally
- Simple update: New state is a function of old state and input, h_t = tanh(W_h h_{t-1} + W_x x_t + b)
Forward Pass
RNNPredictor.py:40-60
Limitations
Short-Term Memory Only
Information from early visits gets exponentially dampened by repeated tanh applications. By visit 5-6, baseline patterns are essentially forgotten.
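This dampening is easy to demonstrate with a scalar recurrence (illustrative toy, not the actual model):

```python
import math

# A signal passed through a tanh recurrence with a typical
# sub-unit recurrent weight decays geometrically per step.
signal = 1.0
w = 0.5  # recurrent weight magnitude < 1
trace = []
for t in range(6):
    signal = math.tanh(w * signal)
    trace.append(signal)

# After ~5-6 steps almost nothing of the original signal remains:
# trace ~ [0.46, 0.23, 0.11, 0.06, 0.03, 0.01]
```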
Gradient Instability
Gradients can explode or vanish during training, requiring careful gradient clipping and learning rate tuning.
Poor Long Sequence Performance
Struggles with sequences longer than 3-4 timesteps without tricks like gradient clipping or careful initialization.
When to Use Vanilla RNN
Best for:
- Baseline comparisons
- Very short sequences (2-3 visits)
- Ablation studies to demonstrate value of gating
- Educational purposes
Architecture Comparison
Performance Characteristics
| Architecture | Parameters | Speed | Memory | Long-Term Memory | Accuracy | Best Use Case |
|---|---|---|---|---|---|---|
| LSTM | High | Slow | High | Excellent | Highest | Default, long sequences |
| GRU | Medium | Fast | Medium | Good | High | Speed needed, shorter sequences |
| RNN | Low | Fastest | Low | Poor | Lowest | Baselines, very short sequences |
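The Parameters column can be made concrete with quick arithmetic based on the standard PyTorch weight layout (n_gates × hidden × (input + hidden) weights plus two bias vectors per gate):

```python
def rnn_param_count(input_dim, hidden_dim, n_gates):
    """Single-layer parameter count matching the torch.nn.{LSTM,GRU,RNN}
    layout: weight_ih + weight_hh + bias_ih + bias_hh."""
    weights = n_gates * hidden_dim * (input_dim + hidden_dim)
    biases = 2 * n_gates * hidden_dim
    return weights + biases

d, h = 512, 128  # embedding dim, hidden dim
lstm_params = rnn_param_count(d, h, n_gates=4)  # 3 gates + candidate
gru_params = rnn_param_count(d, h, n_gates=3)   # update, reset, candidate
rnn_params = rnn_param_count(d, h, n_gates=1)   # single tanh update
# lstm_params = 328,704; gru_params = 246,528; rnn_params = 82,176
```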
Detailed Comparison
The three architectures differ along three axes: parameter count, training speed, and accuracy.
Parameter Count
For hidden_dim=128, input_dim=512 (single layer):
- LSTM: 4 weight blocks per layer, ~329K parameters
- GRU: 3 weight blocks per layer, ~247K parameters
- RNN: 1 weight block per layer, ~82K parameters
Configuration Guide
Model Selection
main.py:38
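An invocation might look like the following; all flag names except `--lstm_bidirectional` are illustrative, so check main.py:38 for the actual argument definitions:

```shell
# Hypothetical invocations - verify flag names against main.py:38
python main.py --temporal_model lstm --lstm_bidirectional True
python main.py --temporal_model gru
python main.py --temporal_model rnn
```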
Hidden Dimension Tuning
- Small (32-64)
- Medium (128)
- Large (256+)
Use a small hidden dimension when:
- Small dataset (< 200 subjects)
- Short sequences (< 4 visits)
- Limited GPU memory
- Risk of overfitting
Layer Depth
Dropout: Only applied between layers when num_layers > 1. For single-layer models, dropout is in the classifier MLP only.
Bidirectional Setting
- Bidirectional: 2× parameters, richer representation
- Unidirectional: Faster, simpler, may be sufficient
Training Dynamics
Loss Function
Focal Loss is used to handle class imbalance:
main.py:28-30
- Alpha (α): Controls class weighting - high (0.9) focuses on the minority class (converters), low (0.5) weights classes equally
- Gamma (γ): Controls focus on hard examples - larger values down-weight easy, well-classified examples
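A pure-Python sketch of binary focal loss, FL(p_t) = -α_t (1 - p_t)^γ log(p_t); the alpha/gamma defaults here are illustrative, the project's values live in main.py:28-30:

```python
import math

def binary_focal_loss(p, target, alpha=0.9, gamma=2.0):
    """FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t).

    gamma down-weights easy, well-classified examples; alpha
    re-weights the positive (minority/converter) class.
    With gamma=0 and alpha=0.5 this is half the cross-entropy.
    """
    p_t = p if target == 1 else 1.0 - p
    a_t = alpha if target == 1 else 1.0 - alpha
    return -a_t * (1.0 - p_t) ** gamma * math.log(p_t)

# An easy example (p_t = 0.95) contributes far less loss than a
# hard one (p_t = 0.30) once gamma > 0:
easy = binary_focal_loss(0.95, 1)
hard = binary_focal_loss(0.30, 1)
```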
Gradient Clipping
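PyTorch provides `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)` for this; the underlying rescaling logic, sketched in plain Python:

```python
import math

def clip_grad_norm(grads, max_norm):
    """Rescale a flat gradient vector so its L2 norm is at most max_norm.

    Gradients below the threshold pass through untouched; larger
    gradients are scaled down uniformly, preserving their direction.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads

clipped = clip_grad_norm([3.0, 4.0], max_norm=1.0)  # norm 5.0 -> 1.0
```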
Learning Rate Scheduling
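One common scheme is reduce-on-plateau (available in PyTorch as `torch.optim.lr_scheduler.ReduceLROnPlateau`); a minimal standalone sketch of the logic:

```python
class PlateauScheduler:
    """Halve the learning rate when validation loss stops improving."""

    def __init__(self, lr, factor=0.5, patience=3):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor  # decay and reset the counter
                self.bad_epochs = 0
        return self.lr

sched = PlateauScheduler(lr=1e-3, patience=2)
for loss in [0.9, 0.8, 0.8, 0.8, 0.8]:  # loss stalls after epoch 2
    lr = sched.step(loss)
```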
Debugging Common Issues
Loss Exploding (NaN)
Symptom: Loss becomes NaN after a few iterations
Causes:
- Vanilla RNN without gradient clipping
- Learning rate too high
- Input embeddings not normalized
Not Learning (Flat Loss)
Symptom: Loss stays constant, accuracy ~50%
Causes:
- Frozen encoder with no fine-tuning
- Learning rate too low
- Sequence lengths not properly handled
Overfitting
Symptom: Training accuracy high, validation accuracy low
Causes:
- Model too complex for dataset size
- Insufficient regularization
- Data leakage (target visit in input)
Advanced: Custom Temporal Models
The modular design allows easy extension. A custom temporal model should:
- Accept [B, T, D] input sequences
- Handle variable lengths via a lengths parameter
- Output [B, num_classes] logits
- Support packed sequences for efficiency
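A toy model satisfying this interface (class and parameter names are illustrative, and packing is omitted for brevity - masking over valid timesteps stands in for it here):

```python
import torch

class MeanPoolPredictor(torch.nn.Module):
    """Toy temporal model: masked mean over valid timesteps + linear head."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.head = torch.nn.Linear(input_dim, num_classes)

    def forward(self, x, lengths):
        # x: [B, T, D] padded sequences; lengths: [B] true visit counts
        mask = (torch.arange(x.size(1))[None, :] < lengths[:, None]).float()
        # Zero out padded timesteps, then average over the valid ones only
        pooled = (x * mask.unsqueeze(-1)).sum(1) / lengths[:, None].float()
        return self.head(pooled)  # [B, num_classes] logits

model = MeanPoolPredictor(input_dim=512, num_classes=2)
logits = model(torch.randn(4, 6, 512), torch.tensor([6, 3, 1, 5]))
```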
Next Steps
GNN Encoding
How graph embeddings are created
Spatiotemporal Integration
Complete pipeline from graphs to predictions
Architecture Details
Full system architecture