
Overview

TemporalDataLoader is a custom data loader that groups brain graphs by subject and builds padded temporal sequences for processing with recurrent neural networks such as LSTMs. It handles variable-length sequences, temporal gap normalization, and batched graph encoding.

Class Definition

class TemporalDataLoader

Constructor

TemporalDataLoader(
    dataset,
    subject_indices,
    encoder,
    device,
    batch_size=8,
    shuffle=True,
    seed=42,
    exclude_target_visit=True,
    time_normalization='log',
    single_visit_horizon=6
)

Parameters

dataset
torch_geometric.data.Dataset
required
The dataset containing brain graphs with temporal information. Should have graphs with subj_id, visit_months, and months_to_next attributes.
subject_indices
list[int]
required
List of dataset indices to include in this data loader (e.g., train, validation, or test split indices).
encoder
torch.nn.Module
required
Graph encoder model (GNN) used to compute graph embeddings. Should accept (x, edge_index, batch, time_features) as input.
device
torch.device
required
Device to place tensors on (e.g., torch.device('cuda') or torch.device('cpu')).
batch_size
int
default:"8"
Number of subjects (sequences) per batch.
shuffle
bool
default:"True"
Whether to shuffle subjects at the start of each epoch.
seed
int
default:"42"
Random seed for reproducible shuffling.
exclude_target_visit
bool
default:"True"
If True, excludes the last visit from input sequence and uses it as prediction target. If False, uses all visits as input.
time_normalization
str
default:"'log'"
Method for normalizing temporal gaps. Options include 'log' and 'linear' (handled by temporal_gap_processor.normalize_time_gaps).
single_visit_horizon
int
default:"6"
Default prediction horizon (in months) for subjects with only one visit.
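The actual normalization lives in temporal_gap_processor.normalize_time_gaps; a plausible sketch of what the 'log' and 'linear' options might compute (the function name and linear scale below are illustrative, not the real implementation):

```python
import math

def normalize_time_gap(months: float, method: str = "log") -> float:
    """Illustrative normalization of a temporal gap given in months."""
    if method == "log":
        # log1p compresses long horizons while mapping 0 months -> 0.0
        return math.log1p(months)
    if method == "linear":
        # simple rescaling by a fixed 12-month horizon (assumed here)
        return months / 12.0
    raise ValueError(f"unknown method: {method}")
```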

Attributes

subject_data
dict
Dictionary mapping base subject IDs to lists of dataset indices (grouped by subject).
subjects
list
List of unique subject IDs in the data loader.

Methods

__iter__

Iterates over temporal sequences in batches.
__iter__() -> Iterator[dict]
Yields:
  • dict: Batch dictionary containing:
    • graph_seq (torch.Tensor): Padded graph embeddings of shape (batch_size, max_seq_len, embed_dim)
    • lengths (torch.Tensor): Actual sequence lengths for each subject, shape (batch_size,)
    • labels (torch.Tensor): Target labels for each subject, shape (batch_size,)
    • time_gaps (torch.Tensor): Normalized temporal gaps to predict, shape (batch_size,)
    • batch_size (int): Actual number of sequences in the batch
Behavior:
  • Groups graphs by subject
  • Computes embeddings for all graphs in a batch using the encoder
  • Handles temporal gap calculation based on exclude_target_visit setting
  • Creates padded sequences for variable-length inputs

__len__

Returns the number of batches per epoch.
__len__() -> int
Returns:
  • int: Number of batches (ceiling division of number of subjects by batch size)
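The batch count is a plain ceiling division over the subject list, which can be sketched as:

```python
import math

def num_batches(num_subjects: int, batch_size: int) -> int:
    # ceiling division: a partial final batch still counts as one batch
    return math.ceil(num_subjects / batch_size)
```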

_group_by_subject

Groups dataset indices by subject ID.
_group_by_subject() -> dict
Returns:
  • dict: Dictionary mapping base subject IDs to lists of their graph indices
Implementation Details:
  • Extracts base subject ID by splitting on '_run'
  • Handles cases where subj_id attribute may not exist
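The grouping logic described above can be sketched as follows (the real method reads subj_id from each graph in the dataset; here plain strings stand in for those attributes):

```python
from collections import defaultdict

def group_by_subject(subj_ids: list[str]) -> dict[str, list[int]]:
    """Map base subject IDs to the dataset indices that share them."""
    groups: dict[str, list[int]] = defaultdict(list)
    for idx, subj_id in enumerate(subj_ids):
        # strip any '_run...' suffix so repeat scans map to one subject
        base_id = subj_id.split("_run")[0]
        groups[base_id].append(idx)
    return dict(groups)
```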

_pad_sequences

Pads sequences to the same length and creates batch tensors.
_pad_sequences(
    sequences,
    labels
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Parameters:
  • sequences (list[torch.Tensor]): List of embedding sequences (variable length)
  • labels (list[int]): List of corresponding labels
Returns:
  • tuple: (padded_sequences, lengths, labels_tensor)
    • padded_sequences: Tensor of shape (batch_size, max_len, embed_dim)
    • lengths: Tensor of actual sequence lengths
    • labels_tensor: Tensor of labels
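A minimal sketch of this padding step, assuming each input sequence is a (seq_len, embed_dim) tensor:

```python
import torch

def pad_sequences(sequences, labels):
    """Zero-pad variable-length embedding sequences into one batch tensor."""
    lengths = torch.tensor([s.size(0) for s in sequences], dtype=torch.long)
    max_len = int(lengths.max())
    embed_dim = sequences[0].size(1)
    padded = torch.zeros(len(sequences), max_len, embed_dim)
    for i, seq in enumerate(sequences):
        padded[i, : seq.size(0)] = seq  # copy real steps; the rest stays zero
    return padded, lengths, torch.tensor(labels, dtype=torch.long)
```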

Usage Example

import torch
from TemporalDataLoader import TemporalDataLoader
from your_model import GraphEncoder, TemporalLSTM

# Assume dataset is already loaded (FC_ADNIDataset or DFC_ADNIDataset)
# Split dataset into train/val/test
train_indices = list(range(0, 800))
val_indices = list(range(800, 900))
test_indices = list(range(900, 1000))

# Select a device and initialize the graph encoder
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = GraphEncoder(
    num_features=dataset.num_features,
    hidden_dim=128,
    embed_dim=64
).to(device)

# Create temporal data loaders
train_loader = TemporalDataLoader(
    dataset=dataset,
    subject_indices=train_indices,
    encoder=encoder,
    device=device,
    batch_size=16,
    shuffle=True,
    seed=42,
    exclude_target_visit=True,
    time_normalization='log',
    single_visit_horizon=6
)

val_loader = TemporalDataLoader(
    dataset=dataset,
    subject_indices=val_indices,
    encoder=encoder,
    device=device,
    batch_size=16,
    shuffle=False,
    exclude_target_visit=True
)

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")

# Initialize LSTM model
lstm_model = TemporalLSTM(
    input_dim=64,
    hidden_dim=128,
    num_classes=3,
    num_layers=2
).to(device)

# Loss, optimizer, and epoch count for the training loop
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm_model.parameters(), lr=1e-3)
num_epochs = 50

# Training loop
for epoch in range(num_epochs):
    lstm_model.train()
    encoder.eval()  # Encoder is typically frozen or trained separately
    
    for batch in train_loader:
        graph_seq = batch['graph_seq']  # (batch_size, max_seq_len, 64)
        lengths = batch['lengths']      # (batch_size,)
        labels = batch['labels']        # (batch_size,)
        time_gaps = batch['time_gaps']  # (batch_size,)
        
        # Forward pass
        logits = lstm_model(graph_seq, lengths, time_gaps)
        loss = criterion(logits, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # Validation
    lstm_model.eval()
    val_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in val_loader:
            graph_seq = batch['graph_seq']
            lengths = batch['lengths']
            labels = batch['labels']
            time_gaps = batch['time_gaps']
            
            logits = lstm_model(graph_seq, lengths, time_gaps)
            loss = criterion(logits, labels)
            
            val_loss += loss.item() * batch['batch_size']
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch['batch_size']
    
    val_acc = correct / total
    print(f"Epoch {epoch+1}: Val Loss = {val_loss/total:.4f}, Val Acc = {val_acc:.4f}")

Temporal Gap Handling

When exclude_target_visit=True (default)

Multiple visits:
  • Input: All visits except the last
  • Target: Last visit’s label
  • Time gap: Months from last input visit to target visit
Single visit:
  • Input: The single visit
  • Target: Same visit’s label (assumes stable state)
  • Time gap: Uses months_to_next if available, otherwise single_visit_horizon

When exclude_target_visit=False

  • Input: All visits
  • Target: Last visit’s label
  • Time gap: Average time gap between consecutive visits for the subject
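The target and gap selection rules above can be sketched as follows (function and argument names are illustrative; visit_months is in months since baseline, and months_to_next may be missing):

```python
def choose_target_and_gap(visit_months, labels, exclude_target_visit=True,
                          months_to_next=None, single_visit_horizon=6):
    """Pick the prediction target label and the time gap to it, in months."""
    if exclude_target_visit:
        if len(visit_months) > 1:
            # input = all but last visit; gap = last input visit -> target
            return labels[-1], visit_months[-1] - visit_months[-2]
        # single visit: predict the same label at a default horizon
        gap = months_to_next if months_to_next is not None else single_visit_horizon
        return labels[-1], gap
    # all visits used as input; gap = mean spacing between consecutive visits
    diffs = [b - a for a, b in zip(visit_months, visit_months[1:])]
    gap = sum(diffs) / len(diffs) if diffs else single_visit_horizon
    return labels[-1], gap
```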

Implementation Details

Batched Embedding Computation

The loader uses an efficient batched approach:
  1. Collects all graphs from multiple subjects in a batch
  2. Creates a single large PyTorch Geometric batch
  3. Computes all embeddings in one forward pass through the encoder
  4. Reshapes embeddings back into per-subject sequences
This is significantly faster than encoding each subject’s graphs sequentially.
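Conceptually, this is what PyTorch Geometric's Batch.from_data_list does: node features from all graphs are concatenated, and a batch vector records which graph each node belongs to. A pure-torch sketch of the same idea (the encoder signature here is simplified and omits time_features):

```python
import torch

@torch.no_grad()
def encode_in_one_pass(node_features_list, encoder):
    """Concatenate node features from many graphs and encode them in one pass."""
    x = torch.cat(node_features_list, dim=0)
    # batch[i] = index of the graph that node i came from
    batch = torch.cat([
        torch.full((feats.size(0),), i, dtype=torch.long)
        for i, feats in enumerate(node_features_list)
    ])
    return encoder(x, batch)
```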

Time Features Integration

When exclude_target_visit=True, the loader creates time features for each graph based on the subject’s prediction horizon. These features are passed to the encoder, allowing time-aware GNN models to incorporate temporal information.

Padding and Masking

Sequences are zero-padded to the maximum length in each batch. The lengths tensor allows models to properly mask padded positions during LSTM processing.
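The lengths tensor is typically consumed via torch's pack_padded_sequence, which makes the LSTM stop at each sequence's true end rather than processing the zero padding; a minimal sketch:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# An LSTM over a zero-padded batch, masked via `lengths`
lstm = torch.nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
padded = torch.randn(2, 5, 4)            # (batch, max_len, embed_dim)
lengths = torch.tensor([5, 3])           # true length of each sequence
packed = pack_padded_sequence(padded, lengths, batch_first=True,
                              enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
# h_n[-1] holds each sequence's final hidden state, ignoring padding
```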

Notes

  • Subject IDs are grouped by base ID (splitting on '_run')
  • Shuffling uses a seeded random number generator for reproducibility
  • The encoder is set to eval mode during embedding computation
  • Empty batches (no valid subjects) are skipped
  • Time normalization is applied using the temporal_gap_processor.normalize_time_gaps function
  • The actual batch size may be smaller than requested if some subjects are filtered out
