
Overview

Temporal prediction models process sequences of brain graph embeddings to learn disease progression patterns. The STGNN framework supports three Recurrent Neural Network (RNN) architectures.

LSTM

Long Short-Term Memory (default choice) - handles long-term dependencies with gating mechanisms

GRU

Gated Recurrent Unit - simpler than LSTM, faster training

RNN

Vanilla RNN - baseline architecture for comparison

LSTM: Long Short-Term Memory (Default)

Architecture

File: TemporalPredictor.py:5-86
class TemporalTabGNNClassifier(nn.Module):
    def __init__(
        self,
        graph_emb_dim: int = 256,      # GNN output dimension (×2 = 512)
        tab_emb_dim: int = 64,         # Optional tabular features
        hidden_dim: int = 128,         # LSTM hidden state size
        num_layers: int = 1,           # LSTM depth
        dropout: float = 0.3,
        bidirectional: bool = False,   # Use BiLSTM?
        num_classes: int = 2
    ):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=graph_emb_dim + tab_emb_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,  # inter-layer dropout only
            bidirectional=bidirectional
        )
Code reference: TemporalPredictor.py:6-29

How LSTM Works

1. Cell State Pipeline

LSTM maintains a cell state $C_t$ that flows through time with minimal modifications.

2. Forget Gate

Decides what information to discard from the cell state:
$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$
  • Value range: [0, 1] via sigmoid
  • Meaning: 1 = keep everything, 0 = forget everything
  • Clinical interpretation: Forget noisy baseline scans, retain stable patterns

3. Input Gate

Decides what new information to add to the cell state:
$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$
$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$
  • Update: $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$
  • Clinical interpretation: Incorporate new connectivity patterns indicating progression

4. Output Gate

Decides what to output as the hidden state:
$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$
$h_t = o_t \odot \tanh(C_t)$
  • Clinical interpretation: Generate a prediction-relevant summary of progression
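The gate equations above can be checked numerically against PyTorch's `nn.LSTMCell`. This is an illustrative sketch with toy dimensions, not code from the repository; note that PyTorch stores separate input and hidden weight matrices (`weight_ih`, `weight_hh`) and two bias vectors rather than a single concatenated $W$ and $b$:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=4, hidden_size=3)
x = torch.randn(1, 4)              # x_t
h = torch.randn(1, 3)              # h_{t-1}
c = torch.randn(1, 3)              # C_{t-1}

# PyTorch packs the four gates row-wise in order: input, forget, cell, output
gates = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
i, f, g, o = gates.chunk(4, dim=1)
i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)  # gates in [0, 1]
g = torch.tanh(g)                  # candidate cell update \tilde{C}_t

c_next = f * c + i * g             # C_t = f_t ⊙ C_{t-1} + i_t ⊙ \tilde{C}_t
h_next = o * torch.tanh(c_next)    # h_t = o_t ⊙ tanh(C_t)

h_ref, c_ref = cell(x, (h, c))
print(torch.allclose(h_next, h_ref, atol=1e-6))  # True
print(torch.allclose(c_next, c_ref, atol=1e-6))  # True
```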

Forward Pass Implementation

def forward(self, graph_seq, tab_seq=None, lengths=None, mask=None):
    # Step 1: Concatenate modalities
    if tab_seq is None or self.tab_emb_dim == 0:
        fused = graph_seq  # [B, T, 512]
    else:
        fused = torch.cat([graph_seq, tab_seq], dim=-1)  # [B, T, 576]
    
    # Step 2: Process through LSTM with packed sequences
    if lengths is not None:
        # Pack to skip padding tokens (efficiency optimization)
        fused_packed = nn.utils.rnn.pack_padded_sequence(
            fused, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        output_packed, (h_n, c_n) = self.lstm(fused_packed)
    else:
        output, (h_n, c_n) = self.lstm(fused)
    
    # Step 3: Extract final hidden state
    if self.bidirectional:
        # Concatenate forward and backward hidden states
        final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)  # [B, 256]
    else:
        final_hidden = h_n[-1]  # [B, 128]
    
    # Step 4: Classification MLP
    logits = self.classifier(final_hidden)  # [B, 2]
    return logits
Code reference: TemporalPredictor.py:40-86

Packed Sequences Optimization

Efficiency: Packed sequences skip computation on padding tokens, crucial when sequences vary from 1 to 10+ visits.
# Example: Batch with sequences of length [3, 7, 2, 5]
# Without packing: Process 7×4 = 28 timesteps (including padding)
# With packing: Process 3+7+2+5 = 17 timesteps (≈39% fewer)

fused_packed = nn.utils.rnn.pack_padded_sequence(
    fused,                # [B, T, 512]
    lengths.cpu(),        # [B] actual lengths
    batch_first=True,
    enforce_sorted=False  # Automatically sorts and unsorts
)
Code reference: TemporalPredictor.py:68-70
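The savings are easy to verify: a packed batch stores only the real timesteps. A small sketch with toy dimensions (not repository code):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=5, batch_first=True)
lengths = torch.tensor([3, 7, 2, 5])          # actual visits per subject
batch = torch.randn(4, 7, 8)                  # padded to max length T=7

packed = nn.utils.rnn.pack_padded_sequence(
    batch, lengths, batch_first=True, enforce_sorted=False
)
# Only the real timesteps are stored: 3 + 7 + 2 + 5 = 17, not 4 × 7 = 28
print(packed.data.shape)                      # torch.Size([17, 8])

out_packed, (h_n, c_n) = lstm(packed)
# h_n holds each sequence's *last real* hidden state, not a padded one.
# pad_packed_sequence restores the padded layout in the original batch order.
out, out_lengths = nn.utils.rnn.pad_packed_sequence(out_packed, batch_first=True)
print(out.shape)                              # torch.Size([4, 7, 5])
```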

Bidirectional LSTM

When bidirectional=True, the LSTM processes each sequence in both directions.
Benefits:
  • Forward pass: Early visits → Late visits (progression trajectory)
  • Backward pass: Late visits → Early visits (retrospective context)
  • Combined: Richer representation of entire sequence
Default: --lstm_bidirectional True
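The layout of h_n explains the `torch.cat([h_n[-2], h_n[-1]], dim=1)` step in the forward pass: for a single-layer BiLSTM, `h_n[-2]` is the forward direction's final state (last timestep) and `h_n[-1]` is the backward direction's final state (first timestep). A sketch with toy dimensions:

```python
import torch
import torch.nn as nn

bilstm = nn.LSTM(input_size=8, hidden_size=5, batch_first=True, bidirectional=True)
x = torch.randn(4, 6, 8)   # [B=4, T=6, D=8]

output, (h_n, c_n) = bilstm(x)
print(output.shape)        # torch.Size([4, 6, 10]) forward/backward concatenated per step
print(h_n.shape)           # torch.Size([2, 4, 5])  [num_layers * 2 directions, B, H]

# Forward final state = last timestep of the forward half of `output`;
# backward final state = first timestep of the backward half.
final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
print(final_hidden.shape)  # torch.Size([4, 10])
```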

Classifier MLP

lstm_output_dim = hidden_dim * (2 if bidirectional else 1)

self.classifier = nn.Sequential(
    nn.Linear(lstm_output_dim, 64),  # 256 → 64 (if bidirectional)
    nn.ReLU(),
    nn.Dropout(dropout),              # Default 0.3
    nn.Linear(64, num_classes)        # 64 → 2 (Normal vs. MCI)
)
Code reference: TemporalPredictor.py:31-38

Advantages of LSTM

  • Long-term memory: the cell state mechanism allows information to flow unchanged across many timesteps, preventing vanishing gradients. Critical for patients with 5+ visits spanning years.
  • Selective memory: gating mechanisms learn what to remember (stable baseline) and what to forget (measurement noise). Important for noisy fMRI data.
  • Stable training: gates prevent gradient explosion/vanishing during backpropagation through time, enabling reliable training.
  • Flexible dynamics: can model both slow progression (retained in the cell state) and sudden changes (captured by the input gate), matching real disease dynamics.

GRU: Gated Recurrent Unit

Architecture

File: GRUPredictor.py:5-59
class GRUPredictor(nn.Module):
    def __init__(self, graph_emb_dim=256, tab_emb_dim=64, 
                 hidden_dim=128, num_layers=1, dropout=0.3,
                 bidirectional=False, num_classes=2):
        super().__init__()
        self.gru = nn.GRU(
            input_size=graph_emb_dim + tab_emb_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )
Code reference: GRUPredictor.py:6-28

How GRU Works

1. Update Gate

Decides how much of the past to keep:
$z_t = \sigma(W_z \cdot [h_{t-1}, x_t])$
  • Combines LSTM's forget and input gates into one
  • Interpretation (per the update equation below): 0 = keep old state, 1 = replace with new

2. Reset Gate

Decides how much of the past to ignore when computing the candidate state:
$r_t = \sigma(W_r \cdot [h_{t-1}, x_t])$

3. New State Computation

$\tilde{h}_t = \tanh(W \cdot [r_t \odot h_{t-1}, x_t])$

4. Final State Update

$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$
  • No separate cell state: simpler than LSTM
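As with the LSTM, the GRU equations can be checked against PyTorch's `nn.GRUCell`. One caveat for this illustrative sketch: PyTorch's convention is $h_t = (1 - z_t) \odot \tilde{h}_t + z_t \odot h_{t-1}$, i.e. $z_t$ weights the old state, the mirror image of the equation above (both conventions appear in the literature):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.GRUCell(input_size=4, hidden_size=3)
x = torch.randn(1, 4)
h = torch.randn(1, 3)

# PyTorch packs the gates row-wise in order: reset, update, new
gi = x @ cell.weight_ih.T + cell.bias_ih
gh = h @ cell.weight_hh.T + cell.bias_hh
i_r, i_z, i_n = gi.chunk(3, dim=1)
h_r, h_z, h_n = gh.chunk(3, dim=1)

r = torch.sigmoid(i_r + h_r)     # reset gate
z = torch.sigmoid(i_z + h_z)     # update gate
n = torch.tanh(i_n + r * h_n)    # candidate state (reset applied to hidden term)

# PyTorch convention: z weights the OLD state
h_next = (1 - z) * n + z * h

print(torch.allclose(h_next, cell(x, h), atol=1e-6))  # True
```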

Forward Pass

def forward(self, graph_seq, tab_seq=None, lengths=None, mask=None):
    # Concatenate modalities
    if tab_seq is None or self.tab_emb_dim == 0:
        fused = graph_seq
    else:
        fused = torch.cat([graph_seq, tab_seq], dim=-1)
    
    # Process with GRU (packed sequences supported)
    if lengths is not None:
        fused_packed = nn.utils.rnn.pack_padded_sequence(
            fused, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        output_packed, h_n = self.gru(fused_packed)  # Note: no cell state
    else:
        output, h_n = self.gru(fused)
    
    # Extract final hidden state
    if self.bidirectional:
        final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
    else:
        final_hidden = h_n[-1]
    
    # Classify
    logits = self.classifier(final_hidden)
    return logits
Code reference: GRUPredictor.py:39-59

GRU vs. LSTM

Computational Efficiency:
  • Fewer parameters (2 gates vs. 3 gates + cell state)
  • Faster training and inference (~25-30% speedup)
  • Lower memory usage
Performance:
  • Often matches LSTM accuracy on many tasks
  • Sometimes generalizes better on smaller datasets
Simplicity:
  • Easier to understand and debug
  • Fewer hyperparameters to tune

When to Use GRU

Best for:
  • Shorter sequences (< 7 visits)
  • Limited computational resources
  • Faster experimentation
  • When LSTM shows overfitting

RNN: Vanilla Recurrent Neural Network

Architecture

File: RNNPredictor.py:5-60
class RNNPredictor(nn.Module):
    def __init__(self, graph_emb_dim=256, tab_emb_dim=64,
                 hidden_dim=128, num_layers=1, dropout=0.3,
                 bidirectional=False, num_classes=2):
        super().__init__()
        self.rnn = nn.RNN(
            input_size=graph_emb_dim + tab_emb_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
            nonlinearity='tanh'  # or 'relu'
        )
Code reference: RNNPredictor.py:6-29

How Vanilla RNN Works

The simplest recurrent architecture:
$h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$
  • No gating mechanisms: all information flows equally
  • Simple update: the new state is a function of the old state and the input
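The single-equation update can be reproduced directly against PyTorch's `nn.RNNCell` (illustrative sketch, toy dimensions; PyTorch splits the bias into `bias_ih` and `bias_hh`):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.RNNCell(input_size=4, hidden_size=3, nonlinearity='tanh')
x = torch.randn(1, 4)
h = torch.randn(1, 3)

# h_t = tanh(W_xh x_t + b_ih + W_hh h_{t-1} + b_hh)
h_next = torch.tanh(
    x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
)
print(torch.allclose(h_next, cell(x, h), atol=1e-6))  # True
```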

Forward Pass

def forward(self, graph_seq, tab_seq=None, lengths=None, mask=None):
    # Concatenate modalities
    if tab_seq is None or self.tab_emb_dim == 0:
        fused = graph_seq
    else:
        fused = torch.cat([graph_seq, tab_seq], dim=-1)
    
    # Process with RNN
    if lengths is not None:
        fused_packed = nn.utils.rnn.pack_padded_sequence(
            fused, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        output_packed, h_n = self.rnn(fused_packed)
    else:
        output, h_n = self.rnn(fused)
    
    # Extract final hidden state
    if self.bidirectional:
        final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
    else:
        final_hidden = h_n[-1]
    
    # Classify
    logits = self.classifier(final_hidden)
    return logits
Code reference: RNNPredictor.py:40-60

Limitations

Vanishing gradients: without gating mechanisms, gradients can shrink exponentially during backpropagation through time, making it hard to learn long-term dependencies.
Short effective memory: information from early visits is exponentially dampened by repeated tanh applications. By visit 5-6, baseline patterns are essentially forgotten.
Unstable training: gradients can also explode, requiring careful gradient clipping and learning rate tuning.
Limited range: struggles with sequences longer than 3-4 timesteps without tricks like gradient clipping or careful initialization.
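The vanishing-gradient effect is easy to observe directly: with a randomly initialized `nn.RNN`, the gradient of a last-timestep loss with respect to the first input is typically orders of magnitude smaller than the gradient at the last input. An illustrative sketch (toy dimensions, not repository code):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, D, H = 50, 8, 8
x = torch.randn(1, T, D, requires_grad=True)

rnn = nn.RNN(D, H, batch_first=True)
out, _ = rnn(x)
out[:, -1].sum().backward()           # loss depends only on the final hidden state

g_first = x.grad[:, 0].norm().item()  # gradient reaching the first timestep
g_last = x.grad[:, -1].norm().item()  # gradient at the final timestep
print(f"t=0: {g_first:.2e}   t={T-1}: {g_last:.2e}")
```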

When to Use Vanilla RNN

Best for:
  • Baseline comparisons
  • Very short sequences (2-3 visits)
  • Ablation studies to demonstrate value of gating
  • Educational purposes

Architecture Comparison

Performance Characteristics

| Architecture | Parameters | Speed | Memory | Long-Term Memory | Accuracy | Best Use Case |
|---|---|---|---|---|---|---|
| LSTM | High | Slow | High | Excellent | Highest | Default, long sequences |
| GRU | Medium | Fast | Medium | Good | High | Speed needed, shorter sequences |
| RNN | Low | Fastest | Low | Poor | Lowest | Baselines, very short sequences |

Detailed Comparison

For hidden_dim=128, input_dim=512:

LSTM:
4 × (hidden_dim × (input_dim + hidden_dim + 1))
= 4 × (128 × (512 + 128 + 1))
= 328,192 parameters
GRU:
3 × (hidden_dim × (input_dim + hidden_dim + 1))
= 3 × (128 × (512 + 128 + 1))
= 246,144 parameters (25% fewer)
RNN:
hidden_dim × (input_dim + hidden_dim + 1)
= 128 × (512 + 128 + 1)
= 82,048 parameters (75% fewer)
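These counts can be cross-checked against PyTorch modules. One caveat: PyTorch keeps two bias vectors per gate group (`bias_ih` and `bias_hh`), so the actual counts are slightly higher than the single-bias formulas above, following k × H × (I + H + 2):

```python
import torch.nn as nn

def n_params(m):
    return sum(p.numel() for p in m.parameters())

I, H = 512, 128
# k = 4 gate groups for LSTM, 3 for GRU, 1 for RNN; two bias vectors each
print(n_params(nn.LSTM(I, H)))  # 328704 = 4 * 128 * (512 + 128 + 2)
print(n_params(nn.GRU(I, H)))   # 246528 = 3 * 128 * (512 + 128 + 2)
print(n_params(nn.RNN(I, H)))   #  82176 = 1 * 128 * (512 + 128 + 2)
```

The relative ratios (GRU ≈ 25% fewer, RNN ≈ 75% fewer than LSTM) are unchanged.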

Configuration Guide

Model Selection

# LSTM (default, recommended)
python main.py --model_type LSTM --lstm_hidden_dim 64 --lstm_bidirectional True

# GRU (faster alternative)
python main.py --model_type GRU --lstm_hidden_dim 64 --lstm_bidirectional True

# RNN (baseline)
python main.py --model_type RNN --lstm_hidden_dim 64 --lstm_bidirectional True
Code reference: main.py:38

Hidden Dimension Tuning

Use a smaller hidden dimension when:
  • Small dataset (< 200 subjects)
  • Short sequences (< 4 visits)
  • Limited GPU memory
  • Risk of overfitting
--lstm_hidden_dim 32  # Very constrained
--lstm_hidden_dim 64  # Default, good balance

Layer Depth

# Single layer (default, recommended)
--lstm_num_layers 1

# Two layers (more capacity, risk of overfitting)
--lstm_num_layers 2
Dropout: Only applied between layers when num_layers > 1. For single-layer models, dropout is in the classifier MLP only.

Bidirectional Setting

# Bidirectional (default, recommended)
--lstm_bidirectional True

# Unidirectional (faster, less memory)
--lstm_bidirectional False
Trade-off:
  • Bidirectional: 2× parameters, richer representation
  • Unidirectional: Faster, simpler, may be sufficient

Training Dynamics

Loss Function

Focal Loss is used to handle class imbalance:
from FocalLoss import FocalLoss

criterion = FocalLoss(
    alpha=0.90,      # Weight for minority class
    gamma=3.0,       # Focusing parameter
    label_smoothing=0.05
)
Code reference: main.py:28-30
$\text{FL}(p_t) = -\alpha_t (1-p_t)^\gamma \log(p_t)$
The alpha parameter controls class weighting:
  • High (0.9): focus on the minority class (converters)
  • Low (0.5): equal weighting
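The repository's FocalLoss module is referenced above; as an illustration of the formula only, a minimal sketch (without label smoothing, treating class 1 as the minority class) might look like:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.90, gamma=3.0):
    # log p_t for the true class of each example
    log_pt = F.log_softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    # alpha_t: alpha for the minority class (label 1), 1 - alpha otherwise
    alpha_t = torch.full_like(pt, 1 - alpha)
    alpha_t[targets == 1] = alpha
    # FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
    return (-alpha_t * (1 - pt) ** gamma * log_pt).mean()

logits = torch.tensor([[2.0, 0.5], [0.1, 1.5]])
targets = torch.tensor([0, 1])
loss = focal_loss(logits, targets)
```

With gamma=0 and alpha=0.5 this reduces to half the standard cross-entropy, which is a quick sanity check.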

Gradient Clipping

# Prevent gradient explosion in RNNs
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
Essential for: Vanilla RNNs, deep LSTM/GRU (3+ layers)

Learning Rate Scheduling

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=10
)
Reduces learning rate when validation loss plateaus.

Debugging Common Issues

Symptom: Loss becomes NaN after a few iterations
Causes:
  • Vanilla RNN without gradient clipping
  • Learning rate too high
  • Input embeddings not normalized
Solutions:
# Add gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Lower learning rate
--lr 0.0001

# Use GRU/LSTM instead of RNN
--model_type LSTM
Symptom: Loss stays constant, accuracy ~50%
Causes:
  • Frozen encoder with no fine-tuning
  • Learning rate too low
  • Sequence lengths not properly handled
Solutions:
# Don't freeze encoder (remove flag)
# (no --freeze_encoder flag)

# Increase learning rate
--lr 0.001

# Check data loader is providing sequences
print(f"Sequence lengths: {lengths}")
Symptom: Training accuracy high, validation accuracy low
Causes:
  • Model too complex for dataset size
  • Insufficient regularization
  • Data leakage (target visit in input)
Solutions:
# Reduce model capacity
--lstm_hidden_dim 32
--lstm_num_layers 1

# Increase regularization
--dropout 0.5

# Ensure target exclusion
--exclude_target_visit True

Advanced: Custom Temporal Models

The modular design allows easy extension:
class CustomTemporalModel(nn.Module):
    def __init__(self, graph_emb_dim=256, ...):
        super().__init__()
        # Your architecture here
        self.temporal_module = ...
        self.classifier = ...
    
    def forward(self, graph_seq, lengths=None, ...):
        # Must accept same inputs as LSTM/GRU/RNN
        x = self.temporal_module(graph_seq)
        logits = self.classifier(x)
        return logits
Requirements:
  1. Accept [B, T, D] input sequences
  2. Handle variable lengths via lengths parameter
  3. Output [B, num_classes] logits
  4. Support packed sequences for efficiency
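A concrete example satisfying these requirements is a hypothetical non-recurrent baseline that mean-pools over visits (the class name and pooling choice are illustrative, not part of the codebase):

```python
import torch
import torch.nn as nn

class MeanPoolPredictor(nn.Module):
    """Hypothetical drop-in baseline: masked mean over visits, no recurrence."""

    def __init__(self, graph_emb_dim=256, hidden_dim=64, num_classes=2, dropout=0.3):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(graph_emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, graph_seq, tab_seq=None, lengths=None, mask=None):
        B, T, D = graph_seq.shape
        if lengths is not None:
            # Zero out padded visits before averaging
            valid = (torch.arange(T, device=graph_seq.device)[None, :]
                     < lengths[:, None]).unsqueeze(-1).float()
            pooled = (graph_seq * valid).sum(1) / valid.sum(1).clamp(min=1)
        else:
            pooled = graph_seq.mean(1)
        return self.classifier(pooled)      # [B, num_classes]

model = MeanPoolPredictor(graph_emb_dim=16)
logits = model(torch.randn(4, 5, 16), lengths=torch.tensor([3, 5, 2, 4]))
print(logits.shape)  # torch.Size([4, 2])
```

Here variable lengths are handled with a mask rather than packed sequences; packing only pays off for recurrent modules.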

Next Steps

GNN Encoding

How graph embeddings are created

Spatiotemporal Integration

Complete pipeline from graphs to predictions

Architecture Details

Full system architecture
