Overview
The temporal predictor modules process sequences of graph embeddings to model Alzheimer’s disease progression over time. Three recurrent architectures are available:
- LSTM (TemporalTabGNNClassifier): Long Short-Term Memory with gating mechanisms
- GRU (GRUPredictor): Gated Recurrent Unit with simplified gating
- RNN (RNNPredictor): Basic recurrent neural network with tanh activation
All predictors support:
- Bidirectional processing
- Optional tabular data fusion
- Packed sequences for variable-length inputs
- Multi-layer stacking
LSTM Predictor
Class Signature
class TemporalTabGNNClassifier(nn.Module):
def __init__(self,
graph_emb_dim: int = 256,
tab_emb_dim: int = 64,
hidden_dim: int = 128,
num_layers: int = 1,
dropout: float = 0.3,
bidirectional: bool = False,
num_classes: int = 2)
Parameters
- graph_emb_dim: Dimension of graph embeddings from the GNN encoder. This should match output_dim * 2 from GraphNeuralNetwork (e.g., 256 for output_dim=128).
- tab_emb_dim: Dimension of tabular feature embeddings. Set to 0 to create a graph-only model without tabular data.
- hidden_dim: Dimension of the LSTM hidden state. Controls the memory capacity of the recurrent network.
- num_layers: Number of stacked LSTM layers. Dropout is automatically applied between layers when num_layers > 1.
- dropout: Dropout rate, applied between LSTM layers (when num_layers > 1) and in the classification head.
- bidirectional: Whether to use a bidirectional LSTM. When True, sequences are processed in both the forward and backward directions, doubling the output dimension.
- num_classes: Number of output classes for classification. Default is 2 for binary classification (e.g., CN vs AD).
Forward Method
def forward(self,
graph_seq: torch.Tensor, # [B, T, graph_emb_dim]
tab_seq: Optional[torch.Tensor] = None, # [B, T, tab_emb_dim]
lengths: Optional[torch.Tensor] = None, # [B]
mask: Optional[torch.Tensor] = None # [B, T]
) -> torch.Tensor:
"""
Args:
graph_seq: Encoded graph sequence [B, T, 256]
tab_seq: Encoded tabular sequence [B, T, 64] (optional)
lengths: True sequence lengths [B] for packed sequences
mask: Attention mask [B, T] (True for real data, False for padding)
Returns:
logits: Classification logits [B, num_classes]
"""
Architecture
The LSTM predictor consists of three stages:
1. Input Fusion
if tab_seq is not None:
fused = torch.cat([graph_seq, tab_seq], dim=-1) # [B, T, 320]
else:
fused = graph_seq # [B, T, 256]
2. LSTM Processing
self.lstm = nn.LSTM(
input_size=graph_emb_dim + tab_emb_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0,
bidirectional=bidirectional
)
3. Classification Head
self.classifier = nn.Sequential(
nn.Linear(lstm_output_dim, 64),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(64, num_classes)
)
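Putting the three stages together, the forward pass can be sketched as a minimal re-implementation (illustrative only; the real class additionally handles packing, masking, and bidirectionality):

```python
import torch
import torch.nn as nn

class MiniTemporalClassifier(nn.Module):
    """Minimal sketch of the fuse -> LSTM -> classify pipeline."""
    def __init__(self, graph_emb_dim=256, tab_emb_dim=64, hidden_dim=128, num_classes=2):
        super().__init__()
        self.lstm = nn.LSTM(graph_emb_dim + tab_emb_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 64), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(64, num_classes))

    def forward(self, graph_seq, tab_seq=None):
        # Stage 1: concatenate modalities along the feature axis
        fused = graph_seq if tab_seq is None else torch.cat([graph_seq, tab_seq], dim=-1)
        # Stage 2: run the LSTM; h_n[-1] is the last layer's final hidden state
        _, (h_n, _) = self.lstm(fused)
        # Stage 3: classify from the final hidden state
        return self.classifier(h_n[-1])

model = MiniTemporalClassifier()
logits = model(torch.randn(4, 5, 256), torch.randn(4, 5, 64))
print(logits.shape)  # torch.Size([4, 2])
```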
Usage Example
import torch
from TemporalPredictor import TemporalTabGNNClassifier
# Initialize LSTM predictor
predictor = TemporalTabGNNClassifier(
graph_emb_dim=512, # GNN outputs 256*2=512
tab_emb_dim=64,
hidden_dim=128,
num_layers=2,
dropout=0.3,
bidirectional=True,
num_classes=2
)
# Forward pass with both modalities
graph_seq = torch.randn(4, 5, 512) # 4 patients, 5 visits, 512-dim embeddings
tab_seq = torch.randn(4, 5, 64) # 4 patients, 5 visits, 64-dim tabular features
lengths = torch.tensor([5, 4, 3, 5]) # True sequence lengths
logits = predictor(graph_seq, tab_seq, lengths)
print(logits.shape) # torch.Size([4, 2])
# Predictions
probs = torch.softmax(logits, dim=1)
predictions = torch.argmax(logits, dim=1)
GRU Predictor
Class Signature
class GRUPredictor(nn.Module):
def __init__(self,
graph_emb_dim: int = 256,
tab_emb_dim: int = 64,
hidden_dim: int = 128,
num_layers: int = 1,
dropout: float = 0.3,
bidirectional: bool = False,
num_classes: int = 2)
Key Differences from LSTM
- Simpler gating mechanism: Uses update and reset gates instead of input, forget, and output gates
- No cell state: Only maintains hidden state (no separate cell state like LSTM)
- Faster training: Fewer parameters and computations than LSTM
- Similar performance: Often comparable to LSTM on many tasks
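The parameter savings can be checked directly: for the same input and hidden sizes, nn.GRU has exactly 3/4 of the weights of nn.LSTM (three gate matrices instead of four), and nn.RNN has 1/4:

```python
import torch.nn as nn

def n_params(module):
    """Total number of learnable parameters."""
    return sum(p.numel() for p in module.parameters())

# Same input/hidden sizes as a fused 320-dim input with hidden_dim=128
lstm = nn.LSTM(input_size=320, hidden_size=128, batch_first=True)
gru = nn.GRU(input_size=320, hidden_size=128, batch_first=True)
rnn = nn.RNN(input_size=320, hidden_size=128, batch_first=True)

print(n_params(lstm), n_params(gru), n_params(rnn))
# GRU uses 3 gate matrices to the LSTM's 4; the plain RNN uses just 1
assert n_params(gru) * 4 == n_params(lstm) * 3
assert n_params(rnn) * 4 == n_params(lstm)
```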
Architecture
self.gru = nn.GRU(
input_size=graph_emb_dim + tab_emb_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0,
bidirectional=bidirectional
)
Usage Example
import torch
from GRUPredictor import GRUPredictor
# Initialize GRU predictor
predictor = GRUPredictor(
graph_emb_dim=512,
tab_emb_dim=64,
hidden_dim=128,
num_layers=1,
dropout=0.3,
bidirectional=False,
num_classes=2
)
# Forward pass
graph_seq = torch.randn(4, 5, 512)
tab_seq = torch.randn(4, 5, 64)
lengths = torch.tensor([5, 4, 3, 5])
logits = predictor(graph_seq, tab_seq, lengths)
print(logits.shape) # torch.Size([4, 2])
RNN Predictor
Class Signature
class RNNPredictor(nn.Module):
def __init__(self,
graph_emb_dim: int = 256,
tab_emb_dim: int = 64,
hidden_dim: int = 128,
num_layers: int = 1,
dropout: float = 0.3,
bidirectional: bool = False,
num_classes: int = 2)
Key Differences
- No gating mechanism: Basic recurrent connections with tanh activation
- Vanishing gradient issues: More susceptible to vanishing gradients on long sequences
- Fastest training: Minimal computational overhead
- Baseline model: Often used as a baseline for comparison
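The ungated update is simply h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh); a hand-rolled loop with nn.RNN's own weights reproduces its output:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=4, hidden_size=3, batch_first=True)
x = torch.randn(2, 6, 4)  # [B=2, T=6, D=4]

# Manual recurrence using the module's weights
W_ih, W_hh = rnn.weight_ih_l0, rnn.weight_hh_l0
b_ih, b_hh = rnn.bias_ih_l0, rnn.bias_hh_l0
h = torch.zeros(2, 3)  # initial hidden state (nn.RNN also defaults to zeros)
outs = []
for t in range(x.size(1)):
    h = torch.tanh(x[:, t] @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
    outs.append(h)
manual = torch.stack(outs, dim=1)  # [B, T, hidden]

ref, _ = rnn(x)
assert torch.allclose(manual, ref, atol=1e-5)
```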
Architecture
self.rnn = nn.RNN(
input_size=graph_emb_dim + tab_emb_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0,
bidirectional=bidirectional,
nonlinearity='tanh' # or 'relu'
)
Usage Example
import torch
from RNNPredictor import RNNPredictor
# Initialize RNN predictor (baseline)
predictor = RNNPredictor(
graph_emb_dim=512,
tab_emb_dim=0, # Graph-only model
hidden_dim=128,
num_layers=1,
dropout=0.3,
bidirectional=False,
num_classes=2
)
# Forward pass (no tabular data)
graph_seq = torch.randn(4, 5, 512)
logits = predictor(graph_seq, tab_seq=None)
print(logits.shape) # torch.Size([4, 2])
Architecture Comparison
| Feature | LSTM | GRU | RNN |
|---|---|---|---|
| Gating | 3 gates (input, forget, output) | 2 gates (update, reset) | None |
| State | Hidden + Cell state | Hidden state only | Hidden state only |
| Parameters | Most (4×hidden²) | Medium (3×hidden²) | Fewest (hidden²) |
| Training Speed | Slowest | Medium | Fastest |
| Long sequences | Best | Good | Poor (vanishing gradient) |
| Memory | Highest | Medium | Lowest |
| Use case | Default choice, long dependencies | Faster alternative to LSTM | Baseline/short sequences |
Packed Sequences
All predictors support packed sequences for efficient batch processing with variable-length inputs:
if lengths is not None:
# Pack sequences to skip padding computations
fused_packed = nn.utils.rnn.pack_padded_sequence(
fused, lengths.cpu(), batch_first=True, enforce_sorted=False
)
output_packed, (h_n, c_n) = self.lstm(fused_packed)
else:
# Standard processing (padding included)
output, (h_n, c_n) = self.lstm(fused)
Benefits:
- Skip computations on padded timesteps
- More efficient memory usage
- Faster training on variable-length sequences
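The saving is visible in the packed object itself: its flat data tensor holds only sum(lengths) timesteps instead of B × T padded ones, and pad_packed_sequence restores the padded layout afterwards:

```python
import torch
import torch.nn as nn

B, T, D = 4, 5, 8
fused = torch.randn(B, T, D)
lengths = torch.tensor([5, 4, 3, 5])

packed = nn.utils.rnn.pack_padded_sequence(
    fused, lengths.cpu(), batch_first=True, enforce_sorted=False)
# Only real timesteps are stored: 5+4+3+5 = 17 rows, not B*T = 20
print(packed.data.shape)  # torch.Size([17, 8])

lstm = nn.LSTM(D, 16, batch_first=True)
out_packed, (h_n, c_n) = lstm(packed)
# Unpack back to a padded [B, T, hidden] tensor in the original batch order
out, out_lens = nn.utils.rnn.pad_packed_sequence(out_packed, batch_first=True)
print(out.shape, out_lens.tolist())  # torch.Size([4, 5, 16]) [5, 4, 3, 5]
```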
Bidirectional Processing
When bidirectional=True, the model processes sequences in both directions:
if self.bidirectional:
# Concatenate forward and backward hidden states
final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1) # [B, 2*hidden_dim]
else:
final_hidden = h_n[-1] # [B, hidden_dim]
Trade-offs:
- Pros: Captures both past and future context, often improves performance
- Cons: 2× parameters, 2× computation, cannot do online/causal prediction
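For a bidirectional LSTM, h_n has shape [num_layers * 2, B, hidden_dim], with forward and backward directions interleaved per layer; the last two slices are the top layer's forward and backward final states:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
               batch_first=True, bidirectional=True)
x = torch.randn(4, 5, 8)  # [B=4, T=5, D=8]
out, (h_n, c_n) = lstm(x)

print(h_n.shape)  # torch.Size([4, 4, 16]) -> [num_layers*2, B, hidden_dim]
# h_n[-2] is the top layer's forward final state, h_n[-1] its backward state
final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
print(final_hidden.shape)  # torch.Size([4, 32])
```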
Graph-Only Mode
To use predictors without tabular data, set tab_emb_dim=0:
predictor = TemporalTabGNNClassifier(
graph_emb_dim=512,
tab_emb_dim=0, # Disable tabular features
hidden_dim=128,
num_classes=2
)
# Forward pass without tab_seq
logits = predictor(graph_seq, tab_seq=None)
The model automatically skips concatenation:
if tab_seq is None or self.tab_emb_dim == 0:
fused = graph_seq
else:
fused = torch.cat([graph_seq, tab_seq], dim=-1)
Output Dimensions
| Configuration | LSTM Output | Classifier Input | Final Output |
|---|---|---|---|
| Unidirectional | hidden_dim | hidden_dim | num_classes |
| Bidirectional | 2 × hidden_dim | 2 × hidden_dim | num_classes |
| 2 layers, hidden=128 | 128 (uni) or 256 (bi) | 128 (uni) or 256 (bi) | 2 |
Training Tips
When to Use Each Architecture
LSTM:
- Default choice for most applications
- Long sequences (> 10 timesteps)
- Complex temporal dependencies
- When computational cost is not critical
GRU:
- Good alternative to LSTM with 25% fewer parameters
- Medium-length sequences (5-15 timesteps)
- When training speed matters
- Limited computational resources
RNN:
- Baseline model for comparison
- Very short sequences (< 5 timesteps)
- When interpretability is important
- Fastest inference time
Hyperparameter Recommendations
# Small dataset (< 500 samples)
predictor = TemporalTabGNNClassifier(
hidden_dim=64,
num_layers=1,
dropout=0.5,
bidirectional=False
)
# Medium dataset (500-2000 samples)
predictor = TemporalTabGNNClassifier(
hidden_dim=128,
num_layers=2,
dropout=0.3,
bidirectional=True
)
# Large dataset (> 2000 samples)
predictor = TemporalTabGNNClassifier(
hidden_dim=256,
num_layers=3,
dropout=0.2,
bidirectional=True
)
See Also