Overview
TemporalDataLoader is a custom data loader that groups brain graphs by subject and creates padded temporal sequences for processing with recurrent neural networks such as LSTMs. It handles variable-length sequences, temporal gap normalization, and batched graph encoding.
Class Definition
Constructor
Parameters
- The dataset containing brain graphs with temporal information. Graphs should have `subj_id`, `visit_months`, and `months_to_next` attributes.
- List of dataset indices to include in this data loader (e.g., train, validation, or test split indices).
- Graph encoder model (GNN) used to compute graph embeddings. Should accept `(x, edge_index, batch, time_features)` as input.
- Device to place tensors on (e.g., `torch.device('cuda')` or `torch.device('cpu')`).
- Number of subjects (sequences) per batch.
- Whether to shuffle subjects at the start of each epoch.
- Random seed for reproducible shuffling.
- If `True`, excludes the last visit from the input sequence and uses it as the prediction target. If `False`, uses all visits as input.
- Method for normalizing temporal gaps. Options include `'log'` and `'linear'` (handled by `temporal_gap_processor.normalize_time_gaps`).
- Default prediction horizon (in months) for subjects with only one visit.
Attributes
- Dictionary mapping base subject IDs to lists of dataset indices (grouped by subject).
- List of unique subject IDs in the data loader.
Methods
__iter__
Iterates over temporal sequences in batches.

Yields:
- `dict`: Batch dictionary containing:
  - `graph_seq` (`torch.Tensor`): Padded graph embeddings of shape `(batch_size, max_seq_len, embed_dim)`
  - `lengths` (`torch.Tensor`): Actual sequence lengths for each subject, shape `(batch_size,)`
  - `labels` (`torch.Tensor`): Target labels for each subject, shape `(batch_size,)`
  - `time_gaps` (`torch.Tensor`): Normalized temporal gaps to predict, shape `(batch_size,)`
  - `batch_size` (`int`): Actual number of sequences in the batch
- Groups graphs by subject
- Computes embeddings for all graphs in a batch using the encoder
- Handles temporal gap calculation based on the `exclude_target_visit` setting
- Creates padded sequences for variable-length inputs
__len__
Returns the number of batches per epoch.

Returns:
- `int`: Number of batches (ceiling division of the number of subjects by the batch size)
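The ceiling division can be sketched with integer arithmetic alone (the function name here is illustrative, not part of the class):

```python
def num_batches(num_subjects: int, batch_size: int) -> int:
    # Ceiling division: any remainder adds one extra (partial) batch.
    return (num_subjects + batch_size - 1) // batch_size
```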
_group_by_subject
Groups dataset indices by subject ID.

Returns:
- `dict`: Dictionary mapping base subject IDs to lists of their graph indices
- Extracts the base subject ID by splitting on `'_run'`
- Handles cases where the `subj_id` attribute may not exist
_pad_sequences
Pads sequences to the same length and creates batch tensors.

Parameters:
- `sequences` (`list[torch.Tensor]`): List of embedding sequences (variable length)
- `labels` (`list[int]`): List of corresponding labels

Returns:
- `tuple`: `(padded_sequences, lengths, labels_tensor)`
  - `padded_sequences`: Tensor of shape `(batch_size, max_len, embed_dim)`
  - `lengths`: Tensor of actual sequence lengths
  - `labels_tensor`: Tensor of labels
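The padding step can be illustrated with nested lists standing in for tensors; the real method would return `torch.Tensor` objects (e.g., via `torch.nn.utils.rnn.pad_sequence`), so this is only a sketch of the shape logic:

```python
def pad_sequences(sequences, labels):
    """Zero-pad variable-length embedding sequences to a common length.

    Illustrative stand-in: lists of per-visit embedding vectors instead
    of torch tensors.
    """
    max_len = max(len(seq) for seq in sequences)
    embed_dim = len(sequences[0][0])
    padded, lengths = [], []
    for seq in sequences:
        lengths.append(len(seq))
        # Append all-zero embedding rows up to the batch maximum length.
        pad = [[0.0] * embed_dim for _ in range(max_len - len(seq))]
        padded.append(list(seq) + pad)
    return padded, lengths, list(labels)
```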
Usage Example
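A sketch of the iteration pattern a training loop would follow. A minimal mock stands in for the real `TemporalDataLoader` so the snippet is self-contained and torch-free; the batch contents are illustrative values, not real embeddings:

```python
class MockTemporalLoader:
    """Stand-in that yields batch dicts with the keys documented under __iter__."""
    def __init__(self, batches):
        self.batches = batches
    def __iter__(self):
        return iter(self.batches)
    def __len__(self):
        return len(self.batches)

loader = MockTemporalLoader([{
    'graph_seq': [[[0.1, 0.2], [0.3, 0.4]]],  # (batch, max_seq_len, embed_dim)
    'lengths': [2],
    'labels': [1],
    'time_gaps': [0.5],
    'batch_size': 1,
}])

for batch in loader:
    # A downstream LSTM would consume graph_seq together with lengths
    # (for masking) and time_gaps; here we just inspect the contract.
    assert set(batch) == {'graph_seq', 'lengths', 'labels',
                          'time_gaps', 'batch_size'}
```

In real use, `batch['graph_seq']` and `batch['lengths']` would typically be passed to `torch.nn.utils.rnn.pack_padded_sequence` before the LSTM.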
Temporal Gap Handling
When exclude_target_visit=True (default)
Multiple visits:
- Input: All visits except the last
- Target: Last visit’s label
- Time gap: Months from last input visit to target visit
Single visit:
- Input: The single visit
- Target: Same visit’s label (assumes stable state)
- Time gap: Uses `months_to_next` if available, otherwise `single_visit_horizon`
When exclude_target_visit=False
- Input: All visits
- Target: Last visit’s label
- Time gap: Average time gap between consecutive visits for the subject
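The gap rules above can be condensed into one hypothetical helper (the function name and the `single_visit_horizon` default of 6.0 are illustrative; the returned gap is raw, i.e., before normalization):

```python
def compute_time_gap(visit_months, months_to_next=None,
                     exclude_target_visit=True, single_visit_horizon=6.0):
    """Raw (un-normalized) time gap in months, per the rules above."""
    if exclude_target_visit:
        if len(visit_months) >= 2:
            # Months from the last input visit to the target (final) visit.
            return visit_months[-1] - visit_months[-2]
        # Single visit: fall back to months_to_next, then the horizon.
        return months_to_next if months_to_next is not None else single_visit_horizon
    # All visits used as input: average gap between consecutive visits.
    gaps = [b - a for a, b in zip(visit_months, visit_months[1:])]
    return sum(gaps) / len(gaps) if gaps else single_visit_horizon
```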
Implementation Details
Batched Embedding Computation
The loader uses an efficient batched approach:
- Collects all graphs from multiple subjects in a batch
- Creates a single large PyTorch Geometric batch
- Computes all embeddings in one forward pass through the encoder
- Reshapes embeddings back into per-subject sequences
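The four steps above can be sketched without torch; `encode_all` stands in for one forward pass of the GNN encoder over a single merged batch (in the real loader, a PyTorch Geometric `Batch`), mapping a flat list of graphs to a flat list of embeddings:

```python
def batched_embed(graph_lists, encode_all):
    """Flatten per-subject graphs, encode once, split back per subject."""
    # Steps 1-2: collect everything into one flat batch.
    flat = [g for graphs in graph_lists for g in graphs]
    # Step 3: a single forward pass over the merged batch.
    embeddings = encode_all(flat)
    # Step 4: reshape embeddings back into per-subject sequences.
    sequences, pos = [], 0
    for graphs in graph_lists:
        sequences.append(embeddings[pos:pos + len(graphs)])
        pos += len(graphs)
    return sequences
```

One forward pass over the merged batch avoids per-subject encoder calls, which is the point of the batched design.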
Time Features Integration
When `exclude_target_visit=True`, the loader creates time features for each graph based on the subject’s prediction horizon. These features are passed to the encoder, allowing time-aware GNN models to incorporate temporal information.
Padding and Masking
Sequences are zero-padded to the maximum length in each batch. The `lengths` tensor allows models to properly mask padded positions during LSTM processing.
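The mask a model would derive from `lengths` can be sketched in plain Python; this makes explicit what `torch.nn.utils.rnn.pack_padded_sequence` achieves implicitly for LSTMs (the helper name is illustrative):

```python
def length_mask(lengths, max_len):
    """Boolean mask from per-subject lengths: True marks a real
    (non-padded) time step, False a zero-padded one."""
    return [[t < n for t in range(max_len)] for n in lengths]
```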
Notes
- Subject IDs are grouped by base ID (splitting on `'_run'`)
- Shuffling uses a seeded random number generator for reproducibility
- The encoder is set to eval mode during embedding computation
- Empty batches (no valid subjects) are skipped
- Time normalization is applied using the `temporal_gap_processor.normalize_time_gaps` function
- The actual batch size may be smaller than requested if some subjects are filtered out