
Overview

The temporal predictor modules process sequences of graph embeddings to model Alzheimer’s disease progression over time. Three recurrent architectures are available:
  • LSTM (TemporalTabGNNClassifier): Long Short-Term Memory with gating mechanisms
  • GRU (GRUPredictor): Gated Recurrent Unit with simplified gating
  • RNN (RNNPredictor): Basic recurrent neural network with tanh activation
All predictors support:
  • Bidirectional processing
  • Optional tabular data fusion
  • Packed sequences for variable-length inputs
  • Multi-layer stacking

LSTM Predictor

Class Signature

class TemporalTabGNNClassifier(nn.Module):
    def __init__(self,
                 graph_emb_dim: int = 256,
                 tab_emb_dim: int = 64,
                 hidden_dim: int = 128,
                 num_layers: int = 1,
                 dropout: float = 0.3,
                 bidirectional: bool = False,
                 num_classes: int = 2)

Parameters

graph_emb_dim
int
default:256
Dimension of graph embeddings from the GNN encoder. This should match output_dim * 2 from GraphNeuralNetwork (e.g., 256 for output_dim=128).
tab_emb_dim
int
default:64
Dimension of tabular feature embeddings. Set to 0 to create a graph-only model without tabular data.
hidden_dim
int
default:128
Dimension of LSTM hidden state. Controls the memory capacity of the recurrent network.
num_layers
int
default:1
Number of stacked LSTM layers. Dropout is automatically applied between layers when num_layers > 1.
dropout
float
default:0.3
Dropout rate applied:
  • Between LSTM layers (when num_layers > 1)
  • In the classification head
bidirectional
bool
default:False
Whether to use bidirectional LSTM. When True, processes sequences in both forward and backward directions, doubling the output dimension.
num_classes
int
default:2
Number of output classes for classification. Default is 2 for binary classification (e.g., CN vs AD).

Forward Method

def forward(self,
            graph_seq: torch.Tensor,           # [B, T, graph_emb_dim]
            tab_seq: Optional[torch.Tensor] = None,  # [B, T, tab_emb_dim]
            lengths: Optional[torch.Tensor] = None,  # [B]
            mask: Optional[torch.Tensor] = None      # [B, T]
            ) -> torch.Tensor:
    """
    Args:
        graph_seq: Encoded graph sequence [B, T, 256]
        tab_seq: Encoded tabular sequence [B, T, 64] (optional)
        lengths: True sequence lengths [B] for packed sequences
        mask: Attention mask [B, T] (True for real data, False for padding)
    
    Returns:
        logits: Classification logits [B, num_classes]
    """
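The mask convention above (True for real visits, False for padding) can be built from `lengths` with a broadcasted comparison; a small self-contained sketch:

```python
import torch

# Build a [B, T] boolean mask from per-patient visit counts:
# True for real timesteps, False for padding.
lengths = torch.tensor([5, 4, 3, 5])  # visits per patient
T = 5                                 # padded sequence length
mask = torch.arange(T)[None, :] < lengths[:, None]  # [B, T]
print(mask[1])  # tensor([ True,  True,  True,  True, False])
```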

Architecture

The LSTM predictor consists of three stages:
  1. Input Fusion
    if tab_seq is not None:
        fused = torch.cat([graph_seq, tab_seq], dim=-1)  # [B, T, 320]
    else:
        fused = graph_seq  # [B, T, 256]
    
  2. LSTM Processing
    self.lstm = nn.LSTM(
        input_size=graph_emb_dim + tab_emb_dim,
        hidden_size=hidden_dim,
        num_layers=num_layers,
        batch_first=True,
        dropout=dropout if num_layers > 1 else 0,
        bidirectional=bidirectional
    )
    
  3. Classification Head
    self.classifier = nn.Sequential(
        nn.Linear(lstm_output_dim, 64),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(64, num_classes)
    )
    

Usage Example

import torch
from TemporalPredictor import TemporalTabGNNClassifier

# Initialize LSTM predictor
predictor = TemporalTabGNNClassifier(
    graph_emb_dim=512,    # GNN outputs 256*2=512
    tab_emb_dim=64,
    hidden_dim=128,
    num_layers=2,
    dropout=0.3,
    bidirectional=True,
    num_classes=2
)

# Forward pass with both modalities
graph_seq = torch.randn(4, 5, 512)  # 4 patients, 5 visits, 512-dim embeddings
tab_seq = torch.randn(4, 5, 64)     # 4 patients, 5 visits, 64-dim tabular features
lengths = torch.tensor([5, 4, 3, 5]) # True sequence lengths

logits = predictor(graph_seq, tab_seq, lengths)
print(logits.shape)  # torch.Size([4, 2])

# Predictions
probs = torch.softmax(logits, dim=1)
predictions = torch.argmax(logits, dim=1)
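The logits above can be trained with standard cross-entropy. A minimal training-step sketch; random logits stand in for the predictor's output so the snippet is self-contained:

```python
import torch
import torch.nn as nn

# Stand-in for `predictor(graph_seq, tab_seq, lengths)`: [B, num_classes]
logits = torch.randn(4, 2, requires_grad=True)
labels = torch.tensor([0, 1, 1, 0])   # diagnosis label per patient

criterion = nn.CrossEntropyLoss()     # expects raw logits, not probabilities
loss = criterion(logits, labels)
loss.backward()                       # gradients flow back to the predictor
print(loss.item())
```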

GRU Predictor

Class Signature

class GRUPredictor(nn.Module):
    def __init__(self,
                 graph_emb_dim: int = 256,
                 tab_emb_dim: int = 64,
                 hidden_dim: int = 128,
                 num_layers: int = 1,
                 dropout: float = 0.3,
                 bidirectional: bool = False,
                 num_classes: int = 2)

Key Differences from LSTM

  • Simpler gating mechanism: Uses update and reset gates instead of input, forget, and output gates
  • No cell state: Only maintains hidden state (no separate cell state like LSTM)
  • Faster training: Fewer parameters and computations than LSTM
  • Similar performance: Often comparable to LSTM on many tasks

Architecture

self.gru = nn.GRU(
    input_size=graph_emb_dim + tab_emb_dim,
    hidden_size=hidden_dim,
    num_layers=num_layers,
    batch_first=True,
    dropout=dropout if num_layers > 1 else 0,
    bidirectional=bidirectional
)

Usage Example

from GRUPredictor import GRUPredictor

# Initialize GRU predictor
predictor = GRUPredictor(
    graph_emb_dim=512,
    tab_emb_dim=64,
    hidden_dim=128,
    num_layers=1,
    dropout=0.3,
    bidirectional=False,
    num_classes=2
)

# Forward pass
graph_seq = torch.randn(4, 5, 512)
tab_seq = torch.randn(4, 5, 64)
lengths = torch.tensor([5, 4, 3, 5])

logits = predictor(graph_seq, tab_seq, lengths)
print(logits.shape)  # torch.Size([4, 2])

RNN Predictor

Class Signature

class RNNPredictor(nn.Module):
    def __init__(self,
                 graph_emb_dim: int = 256,
                 tab_emb_dim: int = 64,
                 hidden_dim: int = 128,
                 num_layers: int = 1,
                 dropout: float = 0.3,
                 bidirectional: bool = False,
                 num_classes: int = 2)

Key Differences

  • No gating mechanism: Basic recurrent connections with tanh activation
  • Vanishing gradient issues: More susceptible to vanishing gradients on long sequences
  • Fastest training: Minimal computational overhead
  • Baseline model: Often used as a baseline for comparison

Architecture

self.rnn = nn.RNN(
    input_size=graph_emb_dim + tab_emb_dim,
    hidden_size=hidden_dim,
    num_layers=num_layers,
    batch_first=True,
    dropout=dropout if num_layers > 1 else 0,
    bidirectional=bidirectional,
    nonlinearity='tanh'  # or 'relu'
)

Usage Example

from RNNPredictor import RNNPredictor

# Initialize RNN predictor (baseline)
predictor = RNNPredictor(
    graph_emb_dim=512,
    tab_emb_dim=0,  # Graph-only model
    hidden_dim=128,
    num_layers=1,
    dropout=0.3,
    bidirectional=False,
    num_classes=2
)

# Forward pass (no tabular data)
graph_seq = torch.randn(4, 5, 512)
logits = predictor(graph_seq, tab_seq=None)
print(logits.shape)  # torch.Size([4, 2])

Architecture Comparison

| Feature | LSTM | GRU | RNN |
| --- | --- | --- | --- |
| Gating | 3 gates (input, forget, output) | 2 gates (update, reset) | None |
| State | Hidden + cell state | Hidden state only | Hidden state only |
| Parameters | Most (4×hidden²) | Medium (3×hidden²) | Fewest (hidden²) |
| Training speed | Slowest | Medium | Fastest |
| Long sequences | Best | Good | Poor (vanishing gradient) |
| Memory | Highest | Medium | Lowest |
| Use case | Default choice, long dependencies | Faster alternative to LSTM | Baseline/short sequences |
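The 4 : 3 : 1 parameter ratio in the table can be verified directly with PyTorch's built-in recurrent layers (sizes below are illustrative):

```python
import torch.nn as nn

def n_params(m):
    return sum(p.numel() for p in m.parameters())

# Same input/hidden sizes for a fair comparison.
kw = dict(input_size=320, hidden_size=128, batch_first=True)
lstm, gru, rnn = nn.LSTM(**kw), nn.GRU(**kw), nn.RNN(**kw)

# Gate counts drive the ratio: 4 (LSTM) : 3 (GRU) : 1 (RNN).
print(n_params(lstm) / n_params(rnn))  # prints 4.0
print(n_params(gru) / n_params(rnn))   # prints 3.0
```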

Packed Sequences

All predictors support packed sequences for efficient batch processing of variable-length inputs. The snippet below shows the LSTM variant; the GRU and RNN versions are identical except that they return only h_n (there is no cell state):
if lengths is not None:
    # Pack sequences to skip padding computations
    fused_packed = nn.utils.rnn.pack_padded_sequence(
        fused, lengths.cpu(), batch_first=True, enforce_sorted=False
    )
    output_packed, (h_n, c_n) = self.lstm(fused_packed)
else:
    # Standard processing (padding included)
    output, (h_n, c_n) = self.lstm(fused)
Benefits:
  • Skip computations on padded timesteps
  • More efficient memory usage
  • Faster training on variable-length sequences
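A self-contained sketch of the same pattern outside the model, using a bare nn.LSTM (sizes are illustrative). Note that after packing, h_n holds each sequence's state at its true last step, not at the padded position T-1:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(3, 5, 8)            # [B=3, T=5, D=8] padded batch
lengths = torch.tensor([5, 3, 2])   # true lengths per sequence

packed = nn.utils.rnn.pack_padded_sequence(
    x, lengths.cpu(), batch_first=True, enforce_sorted=False
)
out_packed, (h_n, c_n) = lstm(packed)

# Unpack for inspection; padded positions are filled with zeros.
out, out_lens = nn.utils.rnn.pad_packed_sequence(out_packed, batch_first=True)
print(out.shape, h_n.shape)  # torch.Size([3, 5, 16]) torch.Size([1, 3, 16])
```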

Bidirectional Processing

When bidirectional=True, the model processes sequences in both directions:
if self.bidirectional:
    # Concatenate forward and backward hidden states
    final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)  # [B, 2*hidden_dim]
else:
    final_hidden = h_n[-1]  # [B, hidden_dim]
Trade-offs:
  • Pros: Captures both past and future context, often improves performance
  • Cons: 2× parameters, 2× computation, cannot do online/causal prediction
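A quick check of the doubled dimensions with a bare nn.LSTM (sizes are illustrative):

```python
import torch
import torch.nn as nn

x = torch.randn(4, 5, 320)  # [B, T, input_size]

uni = nn.LSTM(320, 128, batch_first=True, bidirectional=False)
bi  = nn.LSTM(320, 128, batch_first=True, bidirectional=True)

out_uni, (h_uni, _) = uni(x)
out_bi,  (h_bi,  _) = bi(x)

print(out_uni.shape)  # torch.Size([4, 5, 128])
print(out_bi.shape)   # torch.Size([4, 5, 256]) -- forward + backward

# Final feature for the classifier: concatenate the last forward
# and last backward hidden states, as in the snippet above.
final = torch.cat([h_bi[-2], h_bi[-1]], dim=1)
print(final.shape)    # torch.Size([4, 256])
```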

Graph-Only Mode

To use predictors without tabular data, set tab_emb_dim=0:
predictor = TemporalTabGNNClassifier(
    graph_emb_dim=512,
    tab_emb_dim=0,  # Disable tabular features
    hidden_dim=128,
    num_classes=2
)

# Forward pass without tab_seq
logits = predictor(graph_seq, tab_seq=None)
The model automatically skips concatenation:
if tab_seq is None or self.tab_emb_dim == 0:
    fused = graph_seq
else:
    fused = torch.cat([graph_seq, tab_seq], dim=-1)

Output Dimensions

| Configuration | LSTM Output | Classifier Input | Final Output |
| --- | --- | --- | --- |
| Unidirectional | hidden_dim | hidden_dim | num_classes |
| Bidirectional | 2 × hidden_dim | 2 × hidden_dim | num_classes |
| 2 layers, hidden=128 | 128 or 256 | 128 or 256 | 2 |

Training Tips

When to Use Each Architecture

LSTM:
  • Default choice for most applications
  • Long sequences (> 10 timesteps)
  • Complex temporal dependencies
  • When computational cost is not critical
GRU:
  • Good alternative to LSTM with 25% fewer parameters
  • Medium-length sequences (5-15 timesteps)
  • When training speed matters
  • Limited computational resources
RNN:
  • Baseline model for comparison
  • Very short sequences (< 5 timesteps)
  • When interpretability is important
  • Fastest inference time

Hyperparameter Recommendations

# Small dataset (< 500 samples)
predictor = TemporalTabGNNClassifier(
    hidden_dim=64,
    num_layers=1,
    dropout=0.5,
    bidirectional=False
)

# Medium dataset (500-2000 samples)
predictor = TemporalTabGNNClassifier(
    hidden_dim=128,
    num_layers=2,
    dropout=0.3,
    bidirectional=True
)

# Large dataset (> 2000 samples)
predictor = TemporalTabGNNClassifier(
    hidden_dim=256,
    num_layers=3,
    dropout=0.2,
    bidirectional=True
)
