Overview
TemporalTabGNNClassifier is an LSTM-based model that processes sequences of graph embeddings (optionally combined with tabular features) for temporal classification tasks like AD conversion prediction.
Class Signature
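The original signature was not preserved on this page; the sketch below is a hypothetical reconstruction from the parameter list that follows, and the default values are assumptions, not the project's actual code.

```python
import torch.nn as nn

class TemporalTabGNNClassifier(nn.Module):
    # Hypothetical reconstruction; argument names follow the Parameters
    # section below, default values are assumptions.
    def __init__(self, graph_emb_dim: int, tab_emb_dim: int = 0,
                 hidden_dim: int = 128, num_layers: int = 1,
                 dropout: float = 0.0, bidirectional: bool = False,
                 num_classes: int = 2):
        super().__init__()
        # Fused input width: graph and tabular embeddings concatenated
        self.lstm = nn.LSTM(graph_emb_dim + tab_emb_dim, hidden_dim,
                            num_layers=num_layers,
                            dropout=dropout if num_layers > 1 else 0.0,
                            bidirectional=bidirectional, batch_first=True)
        # Output width doubles for bidirectional LSTMs
        self.classifier = nn.Linear(
            hidden_dim * (2 if bidirectional else 1), num_classes)
```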
Parameters
- graph_emb_dim: Dimension of graph embeddings from the GNN encoder
- tab_emb_dim: Dimension of tabular embeddings. Set to 0 for graph-only models
- hidden_dim: LSTM hidden state dimension
- num_layers: Number of stacked LSTM layers
- dropout: Dropout probability for LSTM and classifier (only applied if num_layers > 1)
- bidirectional: Whether to use a bidirectional LSTM
- num_classes: Number of output classes (e.g., 2 for binary classification)
Forward Method
Parameters
- Graph embedding sequence of shape [batch_size, max_seq_len, graph_emb_dim]
- Optional tabular embedding sequence of shape [batch_size, max_seq_len, tab_emb_dim]; can be None if tab_emb_dim=0
- True sequence lengths for each sample in the batch, shape [batch_size]; used for packed sequences
- Attention mask of shape [batch_size, max_seq_len]; True for real data, False for padding

Returns
Classification logits of shape [batch_size, num_classes]

Architecture Details
Input Fusion
If tabular features are provided, the graph and tabular embeddings are concatenated along the feature dimension.

LSTM Processing
When lengths are provided, the fused sequence is packed so the LSTM processes only the real timesteps of each sample.
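A minimal sketch of the fusion and packing steps, using torch.cat and pack_padded_sequence; tensor names and shapes here are illustrative:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Illustrative batch: names and dimensions are assumptions
graph_seq = torch.randn(4, 10, 8)      # [batch, max_seq_len, graph_emb_dim]
tab_seq = torch.randn(4, 10, 3)        # [batch, max_seq_len, tab_emb_dim]
lengths = torch.tensor([10, 7, 5, 2])  # true length of each sequence

# Input fusion: concatenate along the feature dimension
x = torch.cat([graph_seq, tab_seq], dim=-1)          # [4, 10, 11]

# Pack so the LSTM skips padded timesteps
packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True,
                              enforce_sorted=False)
```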
Hidden State Extraction
In the bidirectional case, the last layer's final forward and backward hidden states are concatenated, giving a vector of size hidden_dim * 2.

Classification Head
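The extraction and classification steps can be sketched as follows; this is a hypothetical reconstruction that follows torch.nn.LSTM's h_n layout of [num_layers * num_directions, batch, hidden_dim]:

```python
import torch
import torch.nn as nn

batch, hidden_dim, num_classes = 4, 16, 2
lstm = nn.LSTM(11, hidden_dim, num_layers=2, bidirectional=True,
               batch_first=True)
x = torch.randn(batch, 10, 11)
_, (h_n, _) = lstm(x)   # h_n: [num_layers * 2, batch, hidden_dim]

# Hidden state extraction: concatenate the last layer's forward (h_n[-2])
# and backward (h_n[-1]) final hidden states
final = torch.cat([h_n[-2], h_n[-1]], dim=-1)   # [batch, hidden_dim * 2]

# Classification head: a linear layer producing logits
classifier = nn.Linear(hidden_dim * 2, num_classes)
logits = classifier(final)                      # [batch, num_classes]
```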
Example Usage
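The snippet below is a self-contained usage sketch; the class body is a stand-in reconstructed from this page's descriptions, not the project's actual implementation, and all names are assumptions:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

# Stand-in mirroring the documented interface; internals are assumptions.
class TemporalTabGNNClassifier(nn.Module):
    def __init__(self, graph_emb_dim, tab_emb_dim=0, hidden_dim=64,
                 num_layers=1, dropout=0.0, bidirectional=False,
                 num_classes=2):
        super().__init__()
        self.lstm = nn.LSTM(graph_emb_dim + tab_emb_dim, hidden_dim,
                            num_layers=num_layers,
                            dropout=dropout if num_layers > 1 else 0.0,
                            bidirectional=bidirectional, batch_first=True)
        self.classifier = nn.Linear(
            hidden_dim * (2 if bidirectional else 1), num_classes)

    def forward(self, graph_seq, tab_seq=None, lengths=None, mask=None):
        # Input fusion (skipped for graph-only models)
        x = graph_seq if tab_seq is None else torch.cat([graph_seq, tab_seq], -1)
        # Pack variable-length sequences when lengths are given
        if lengths is not None:
            x = pack_padded_sequence(x, lengths.cpu(), batch_first=True,
                                     enforce_sorted=False)
        _, (h_n, _) = self.lstm(x)
        if self.lstm.bidirectional:
            h = torch.cat([h_n[-2], h_n[-1]], dim=-1)
        else:
            h = h_n[-1]
        return self.classifier(h)

model = TemporalTabGNNClassifier(graph_emb_dim=32, tab_emb_dim=8,
                                 hidden_dim=64, bidirectional=True)
graph_seq = torch.randn(4, 12, 32)   # [batch, max_seq_len, graph_emb_dim]
tab_seq = torch.randn(4, 12, 8)      # [batch, max_seq_len, tab_emb_dim]
lengths = torch.tensor([12, 9, 6, 3])
logits = model(graph_seq, tab_seq, lengths)   # [batch, num_classes]
```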
See main.py:257 for the in-repo usage.

Training Example
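A minimal training-loop sketch with cross-entropy loss; the simplified graph-only stand-in model and the toy data are illustrative assumptions, not the project's pipeline:

```python
import torch
import torch.nn as nn

# Simplified graph-only stand-in (tab_emb_dim=0); shown only to
# illustrate the training loop, not the real model.
class TinyTemporalClassifier(nn.Module):
    def __init__(self, graph_emb_dim=16, hidden_dim=32, num_classes=2):
        super().__init__()
        self.lstm = nn.LSTM(graph_emb_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, graph_seq):
        _, (h_n, _) = self.lstm(graph_seq)
        return self.classifier(h_n[-1])

model = TinyTemporalClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

graph_seq = torch.randn(8, 5, 16)     # toy batch: [batch, seq, emb]
labels = torch.randint(0, 2, (8,))    # binary targets

for epoch in range(3):
    optimizer.zero_grad()
    logits = model(graph_seq)
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()
```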
Notes
- Input dimension is graph_emb_dim + tab_emb_dim
- LSTM output dimension is hidden_dim * (2 if bidirectional else 1)
- Packed sequences improve efficiency for variable-length inputs
- Set tab_emb_dim=0 for graph-only models
- Dropout is only applied between LSTM layers when num_layers > 1