Overview

Spatiotemporal modeling captures both where changes occur in the brain (spatial) and when they occur (temporal), providing a comprehensive view of disease progression.

Spatial Feature Extraction

Brain Graphs as Structured Data

Each fMRI scan is represented as a graph where:
  • Nodes: Brain regions (typically 100-400 ROIs)
  • Edges: Functional connectivity between regions
  • Node Features: Statistical properties of each region

GNN Spatial Processing

The Graph Neural Network extracts spatial patterns through message passing:

Step 1: Local Aggregation

Each node aggregates information from its connected neighbors:
# GraphSAGE aggregation
x = self.convs[i](x, edge_index)  # Aggregate neighbor features
This captures local connectivity patterns - which brain regions are communicating.

Step 2: Multi-Hop Propagation

With multiple layers (default 2-3), information propagates across the graph:
  • Layer 1: Immediate neighbors
  • Layer 2: 2-hop neighborhood
  • Layer 3: 3-hop neighborhood
This captures network-level organization and community structure.
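
The multi-hop effect can be illustrated with a toy mean-aggregation sketch (plain Python, not the actual GraphSAGE layer): after one round a node only sees its immediate neighbors, after two rounds it is influenced by nodes two hops away.

```python
# Toy line graph: 0 - 1 - 2. Node 2 starts with a distinctive feature value.
neighbors = {0: [1], 1: [0, 2], 2: [1]}
features = {0: 0.0, 1: 0.0, 2: 1.0}

def propagate(feats):
    """One round of mean aggregation over self + neighbors."""
    return {
        n: (feats[n] + sum(feats[m] for m in neighbors[n])) / (1 + len(neighbors[n]))
        for n in feats
    }

after_1 = propagate(features)   # node 0 is still unaffected by node 2
after_2 = propagate(after_1)    # node 2's signal now reaches node 0 (2 hops away)

print(after_1[0], after_2[0])
```

With one layer, node 0's feature stays at 0.0; with two layers it becomes nonzero, which is exactly why layer depth controls the size of the receptive neighborhood.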

Step 3: Hierarchical Pooling

TopK pooling progressively focuses on important regions:
x, edge_index, _, batch, perm, score = self.topk_pools[i](
    x, edge_index, batch=batch
)
This identifies critical nodes - regions most relevant for classification.
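
The node-selection step of TopK pooling can be sketched in NumPy (a simplified illustration only; the real layer also gates features by the score and filters edge_index):

```python
import numpy as np

def topk_pool(x, scores, ratio=0.5):
    """Keep the ceil(ratio * N) highest-scoring nodes (simplified TopK pooling)."""
    k = max(1, int(np.ceil(ratio * len(scores))))
    perm = np.argsort(scores)[::-1][:k]  # indices of the top-k nodes by score
    return x[perm], perm

x = np.arange(8, dtype=float).reshape(4, 2)  # 4 nodes, 2 features each
scores = np.array([0.1, 0.9, 0.4, 0.7])      # hypothetical learned node scores
pooled, perm = topk_pool(x, scores, ratio=0.5)
print(perm)  # [1 3] -- the two highest-scoring nodes survive
```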

Step 4: Global Representation

Dual pooling creates a graph-level feature vector:
x_mean = global_mean_pool(x, batch)  # Average pattern
x_max = global_max_pool(x, batch)    # Peak activations
x = torch.cat([x_mean, x_max], dim=1)  # [B, 512]
This produces a fixed-size embedding representing the entire brain state.
Spatial Encoding Output: Each brain scan → 512-dimensional vector capturing connectivity fingerprint
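
A minimal NumPy sketch of the dual pooling step (toy dimensions; the real model concatenates two 256-D vectors into the 512-D embedding):

```python
import numpy as np

rng = np.random.default_rng(0)
node_embeddings = rng.random((5, 4))  # hypothetical: 5 nodes, 4 features each

x_mean = node_embeddings.mean(axis=0)  # average pattern across nodes
x_max = node_embeddings.max(axis=0)    # peak activation per feature
graph_embedding = np.concatenate([x_mean, x_max])  # fixed size: 2 * 4 = 8

print(graph_embedding.shape)  # (8,) regardless of how many nodes the graph has
```

The key property is that the output size depends only on the feature dimension, not the node count, so graphs of any size map to a fixed-length vector.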

Temporal Sequence Construction

Multi-Visit Data Organization

The TemporalDataLoader organizes patient data into sequences:
def _group_by_subject(self):
    """Group dataset indices by subject ID."""
    subject_groups = {}
    for idx in self.subject_indices:
        data = self.dataset[idx]
        sid = getattr(data, 'subj_id', None)
        if sid is None:
            continue  # Skip samples without a subject ID
        # Strip the run suffix so repeat scans map to the same subject
        base_id = sid.split('_run')[0] if '_run' in sid else sid
        subject_groups.setdefault(base_id, []).append(idx)
    return subject_groups
Code reference: TemporalDataLoader.py:34-44
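
The grouping logic can be demonstrated standalone with hypothetical scan IDs:

```python
def base_subject_id(sid):
    """Strip a '_run' suffix so repeat scans map to one subject."""
    return sid.split('_run')[0] if '_run' in sid else sid

scan_ids = ['sub-01_run1', 'sub-01_run2', 'sub-02', 'sub-03_run1']
groups = {}
for idx, sid in enumerate(scan_ids):
    groups.setdefault(base_subject_id(sid), []).append(idx)

print(groups)  # {'sub-01': [0, 1], 'sub-02': [2], 'sub-03': [3]}
```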

Sequence Construction Modes

Mode: exclude_target_visit=True
Purpose: Predict future state without data leakage
Logic:
if len(subject_indices) > 1:
    # Use all visits except last as input
    input_indices = subject_indices[:-1]
    target_idx = subject_indices[-1]
    
    # Calculate time gap to predict
    last_input_data = self.dataset[input_indices[-1]]
    target_data = self.dataset[target_idx]
    time_to_predict = target_data.visit_months - last_input_data.visit_months
    
    # Label comes from target visit
    label = target_data.y.item()
else:
    # Single visit: predict default horizon ahead (6 months)
    input_indices = subject_indices
    time_to_predict = self.single_visit_horizon
    label = self.dataset[subject_indices[0]].y.item()
Code reference: TemporalDataLoader.py:92-132
This mode is crucial for true future prediction. Including the target visit in the input would allow the model to “see the future”.
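
The input/target split can be sketched as a small pure-Python function (a simplified version of the loader logic above; the real code also pulls labels and graph data):

```python
def build_sequence(visit_indices, visit_months, single_visit_horizon=6):
    """Split a subject's visits into input visits and a prediction horizon."""
    if len(visit_indices) > 1:
        input_indices = visit_indices[:-1]           # all but the last visit
        time_to_predict = visit_months[-1] - visit_months[-2]
    else:
        input_indices = visit_indices                # single visit
        time_to_predict = single_visit_horizon       # default 6-month horizon
    return input_indices, time_to_predict

# Subject with dataset indices 10..13 scanned at months 0, 6, 12, 24
inputs, gap = build_sequence([10, 11, 12, 13], [0, 6, 12, 24])
print(inputs, gap)  # [10, 11, 12] 12
```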

Temporal Gap Normalization

Time intervals vary widely (1 month to 5+ years). Normalization stabilizes training:
from temporal_gap_processor import normalize_time_gaps

time_to_predict_normalized = normalize_time_gaps(
    np.array([time_to_predict]), 
    method=self.time_normalization  # 'log', 'minmax', 'buckets', 'raw'
)[0]
Code reference: TemporalDataLoader.py:134-138
With method='log', the transform is:
normalized = np.log1p(time_gaps)  # log(1 + x)
Benefits:
  • Compresses large time ranges
  • 6 months → 1.95, 12 months → 2.56, 24 months → 3.22
  • Preserves relative ordering
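
The quoted values can be reproduced directly:

```python
import numpy as np

time_gaps = np.array([6.0, 12.0, 24.0])  # months between visits
normalized = np.log1p(time_gaps)         # log(1 + x)

print(np.round(normalized, 2))  # [1.95 2.56 3.22]
# A 4x spread in raw months (6 -> 24) shrinks to a ~1.65x spread after log1p,
# while the ordering 6 < 12 < 24 is preserved.
```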

Sequence Padding and Batching

Variable-length sequences must be padded for batch processing:
def _pad_sequences(self, sequences, labels):
    batch_size = len(sequences)
    lengths = torch.tensor([seq.size(0) for seq in sequences], dtype=torch.long)
    max_len = lengths.max().item()
    embed_dim = sequences[0].size(1)  # 512
    
    # Create padded tensor
    padded = torch.zeros(batch_size, max_len, embed_dim, device=self.device)
    for i, seq in enumerate(sequences):
        padded[i, :seq.size(0)] = seq  # Fill real data
    
    labels_tensor = torch.tensor(labels, dtype=torch.long, device=self.device)
    return padded, lengths.to(self.device), labels_tensor
Code reference: TemporalDataLoader.py:49-63
Sequence Lengths: The lengths tensor allows the RNN to use packed sequences, ignoring padding and processing only real timesteps.
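
A NumPy sketch of the same padding idea with toy dimensions (the real loader uses 512-D torch tensors on the training device):

```python
import numpy as np

def pad_sequences(sequences):
    """Zero-pad variable-length sequences to a common length."""
    lengths = [len(seq) for seq in sequences]
    max_len, dim = max(lengths), sequences[0].shape[1]
    padded = np.zeros((len(sequences), max_len, dim))
    for i, seq in enumerate(sequences):
        padded[i, :len(seq)] = seq  # fill real timesteps, leave the rest zero
    return padded, lengths

# Two hypothetical subjects: 3 visits and 1 visit, 4-D embeddings
seqs = [np.ones((3, 4)), np.ones((1, 4))]
padded, lengths = pad_sequences(seqs)
print(padded.shape, lengths)  # (2, 3, 4) [3, 1]
# padded[1, 1:] is all zeros -- the lengths list tells the RNN to skip it.
```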

Temporal Pattern Learning

RNN Processing of Sequences

The temporal model processes the sequence to learn progression patterns:
if lengths is not None:
    # Pack sequences to skip padding
    fused_packed = nn.utils.rnn.pack_padded_sequence(
        fused, lengths.cpu(), batch_first=True, enforce_sorted=False
    )
    output_packed, (h_n, c_n) = self.lstm(fused_packed)
else:
    output, (h_n, c_n) = self.lstm(fused)

# Extract final hidden state (summary of entire sequence)
if self.bidirectional:
    final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
else:
    final_hidden = h_n[-1]
Code reference: TemporalPredictor.py:66-82

What the RNN Learns

The RNN learns whether connectivity is:
  • Stable: Consistent patterns across visits
  • Gradually declining: Progressive degradation
  • Rapidly deteriorating: Fast conversion trajectory
  • Fluctuating: Noisy or recovering patterns
The sequential processing captures:
  • Direction of change (improving vs. worsening)
  • Rate of change (slope of progression)
  • Acceleration (is decline speeding up?)
LSTM’s gating mechanisms remember:
  • Baseline connectivity from early visits
  • Cumulative changes over time
  • Critical transition points

Time-Aware Enhancement

Optional temporal gap features explicitly inform the model about prediction horizon:

GNN-Level Time Integration

if self.use_time_features and time_to_predict is not None:
    # Project normalized time gap to 32D
    time_embedding = self.time_projection(time_to_predict)  # [B, 32]
    
    # Combine with graph embedding
    combined = torch.cat([graph_embedding, time_embedding], dim=1)  # [B, 544]
    final_embedding = self.fusion_layer(combined)  # [B, 512]
Code reference: model.py:120-132

Design Rationale

Why Small Time Dimension?

Time features use only 32D (vs. 512D for graphs) to prevent temporal information from overwhelming spatial patterns. The brain connectivity patterns should drive predictions, with time serving as context.
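
The dimension budget can be sketched with stand-in matrices (random weights here, not the trained fusion layer):

```python
import numpy as np

rng = np.random.default_rng(0)

graph_embedding = rng.standard_normal((1, 512))  # spatial features
time_embedding = rng.standard_normal((1, 32))    # projected time gap

combined = np.concatenate([graph_embedding, time_embedding], axis=1)  # [1, 544]
fusion_weights = rng.standard_normal((544, 512))  # stand-in for fusion_layer
final_embedding = combined @ fusion_weights       # back to [1, 512]

# Time contributes only 32 of 544 fused input dimensions (~6%),
# so spatial connectivity dominates the representation.
print(final_embedding.shape)  # (1, 512)
```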

Time Feature Flow

(Diagram: normalized time gap → time_projection → 32-D time embedding → concatenated with 512-D graph embedding → fusion_layer → 512-D final embedding)

Integration Example

Complete pipeline for a single patient:

Step 1: Data Collection

Patient has 4 visits: Months 0, 6, 12, 24
Labels: Normal, Normal, Normal, MCI (converted at Month 24)

Step 2: Sequence Construction

With exclude_target_visit=True:
  • Input: Visits at months 0, 6, 12 (3 timesteps)
  • Target: MCI label from month 24
  • Time gap: 12 months (24 - 12)

Step 3: Spatial Encoding

Each input visit is encoded by GNN:
  • Visit 0 → Embedding 1 [512D]
  • Visit 6 → Embedding 2 [512D]
  • Visit 12 → Embedding 3 [512D]
Optional: Time gap (12 months) fused into each embedding

Step 4: Temporal Processing

LSTM processes sequence [3, 512]:
  • Learns progression from Normal → Normal → Normal
  • Predicts future state 12 months ahead
  • Output: Hidden state [128D]

Step 5: Classification

MLP classifier:
  • Input: Hidden state [128D]
  • Output: Logits [2D] → [Normal, MCI]
  • Prediction: MCI (correct!)

Performance Optimization

Batched Embedding Computation

Instead of encoding each visit separately:
# Collect ALL graphs from ALL subjects in batch
all_data = []  # e.g., 3+4+2+5 = 14 graphs from 4 subjects
for subject in batch_subjects:
    for visit in subject_visits:
        all_data.append(graph_data)

# Single forward pass through GNN
big_batch = Batch.from_data_list(all_data)
all_embeddings = encoder(big_batch.x, big_batch.edge_index, big_batch.batch)

# Reshape back to per-subject sequences
for subject_idx in range(num_subjects):
    subject_sequence = [all_embeddings[i] for i in subject_indices]
    batch_sequences.append(torch.stack(subject_sequence))
Code reference: TemporalDataLoader.py:194-224
Speedup: Processing 14 graphs in a single batch is ~10x faster than 14 separate forward passes due to GPU parallelization.
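
The flatten-then-reshape bookkeeping can be shown with plain lists (integer stand-ins replace the real 512-D embedding tensors):

```python
# Visits per subject in one batch (the counts from the example above)
visits_per_subject = [3, 4, 2, 5]

# Flatten: one big list of graphs for a single GNN forward pass
flat_embeddings = list(range(sum(visits_per_subject)))  # 14 stand-in embeddings

# Reshape back to per-subject sequences using cumulative offsets
sequences, offset = [], 0
for n in visits_per_subject:
    sequences.append(flat_embeddings[offset:offset + n])
    offset += n

print([len(s) for s in sequences])  # [3, 4, 2, 5]
```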

Practical Considerations

Missing Visits

Handled automatically: Sequences include only available visits. The RNN processes variable-length sequences naturally via packed sequences.

Irregular Intervals

Normalized: Log transformation handles intervals from 1 month to 5+ years. Optional time features explicitly model intervals.

Single Visits

Prediction horizon: Default 6 months ahead, configurable via --single_visit_horizon. Uses current state to predict future.

Sequence Length Limits

Truncation: Set via --max_visits (default 10). Keeps most recent visits when patients have excessive follow-up.

Evaluation Strategy

Subject-Level Splitting: Train/val/test splits are done at the subject level, not visit level. All visits from a patient stay together, preventing data leakage.
from sklearn.model_selection import StratifiedKFold

def get_kfold_splits(dataset, num_folds=5, seed=42):
    # Group by subject
    subject_labels = {}
    for subj_id, graphs in dataset.subject_graph_dict.items():
        subject_labels[subj_id] = graphs[0].y.item()  # Use first visit label

    subjects = list(subject_labels.keys())
    labels = list(subject_labels.values())

    # Stratified split on subjects
    skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=seed)
    for train_val_idx, test_idx in skf.split(subjects, labels):
        # Further split train_val into train and validation
        ...
Code reference: main.py:117-150
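
The leakage guarantee is easy to check in isolation (a pure-Python sketch with hypothetical subject IDs; the real code stratifies with sklearn):

```python
# Each tuple is (subject_id, visit_month)
visits = [('sub-01', 0), ('sub-01', 6), ('sub-02', 0), ('sub-03', 0), ('sub-03', 12)]

def split_by_subject(visits, test_subjects):
    """Assign every visit to the split that owns its subject."""
    train = [v for v in visits if v[0] not in test_subjects]
    test = [v for v in visits if v[0] in test_subjects]
    return train, test

train, test = split_by_subject(visits, {'sub-03'})

# No subject appears in both splits -> no leakage across a patient's visits
assert {s for s, _ in train} & {s for s, _ in test} == set()
print(len(train), len(test))  # 3 2
```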

Next Steps

  • GNN Layers: Spatial encoding details
  • RNN Architectures: Temporal modeling options
  • Full Architecture: Complete system overview
