Overview

TemporalTabGNNClassifier is an LSTM-based model that processes sequences of graph embeddings (optionally combined with tabular features) for temporal classification tasks like AD conversion prediction.

Class Signature

class TemporalTabGNNClassifier(nn.Module):
    def __init__(
        self,
        graph_emb_dim: int = 256,
        tab_emb_dim: int = 64,
        hidden_dim: int = 128,
        num_layers: int = 1,
        dropout: float = 0.3,
        bidirectional: bool = False,
        num_classes: int = 2
    )

Parameters

graph_emb_dim
int
default:"256"
Dimension of graph embeddings from GNN encoder
tab_emb_dim
int
default:"64"
Dimension of tabular embeddings. Set to 0 for graph-only models
hidden_dim
int
default:"128"
LSTM hidden state dimension
num_layers
int
default:"1"
Number of stacked LSTM layers
dropout
float
default:"0.3"
Dropout probability for the LSTM and classifier head. In the LSTM it is applied between stacked layers only when num_layers > 1; the classifier head applies it always
bidirectional
bool
default:"False"
Whether to use bidirectional LSTM
num_classes
int
default:"2"
Number of output classes (e.g., 2 for binary classification)

Forward Method

def forward(
    self,
    graph_seq: torch.Tensor,
    tab_seq: Optional[torch.Tensor] = None,
    lengths: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None
) -> torch.Tensor

Parameters

graph_seq
torch.Tensor
Graph embedding sequence of shape [batch_size, max_seq_len, graph_emb_dim]
tab_seq
torch.Tensor
default:"None"
Optional tabular embedding sequence of shape [batch_size, max_seq_len, tab_emb_dim]. Can be None if tab_emb_dim=0
lengths
torch.Tensor
default:"None"
True sequence lengths for each sample in the batch, shape [batch_size]. When provided, inputs are packed so the LSTM skips padded timesteps
mask
torch.Tensor
default:"None"
Attention mask of shape [batch_size, max_seq_len]. True for real data, False for padding

Returns

logits
torch.Tensor
Classification logits of shape [batch_size, num_classes]

Architecture Details

Input Fusion

If tabular features provided:
fused = torch.cat([graph_seq, tab_seq], dim=-1)  # [B, T, graph_emb_dim + tab_emb_dim]
Graph-only mode:
fused = graph_seq  # [B, T, graph_emb_dim]
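
The fusion step can be sanity-checked with random tensors (shapes only; the dimensions below are the documented defaults):

```python
import torch

B, T = 4, 10                        # batch size, max sequence length
graph_emb_dim, tab_emb_dim = 256, 64

graph_seq = torch.randn(B, T, graph_emb_dim)
tab_seq = torch.randn(B, T, tab_emb_dim)

# Concatenate along the feature dimension
fused = torch.cat([graph_seq, tab_seq], dim=-1)
print(fused.shape)  # torch.Size([4, 10, 320])
```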

LSTM Processing

With lengths provided (packed sequences):
fused_packed = nn.utils.rnn.pack_padded_sequence(fused, lengths.cpu(), batch_first=True, enforce_sorted=False)
output_packed, (h_n, c_n) = self.lstm(fused_packed)
Standard processing:
output, (h_n, c_n) = self.lstm(fused)
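
The packed-sequence path can be reproduced with a standalone nn.LSTM (the dimensions here are illustrative, not the model's internals):

```python
import torch
import torch.nn as nn

B, T, D, H = 4, 10, 320, 128
lstm = nn.LSTM(D, H, num_layers=1, batch_first=True)

fused = torch.randn(B, T, D)
lengths = torch.tensor([10, 7, 5, 3])  # true length of each sample

# Packing makes the recurrence stop at each sample's true last timestep
packed = nn.utils.rnn.pack_padded_sequence(
    fused, lengths.cpu(), batch_first=True, enforce_sorted=False
)
output_packed, (h_n, c_n) = lstm(packed)

# h_n is [num_layers, B, H]: the hidden state at each true final timestep
print(h_n.shape)  # torch.Size([1, 4, 128])
```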

Hidden State Extraction

Bidirectional:
final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)  # [B, 2*hidden_dim]
Unidirectional:
final_hidden = h_n[-1]  # [B, hidden_dim]
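
In the bidirectional case, h_n has shape [num_layers * 2, B, hidden_dim] and its last two entries are the top layer's forward and backward final states, which is why they are concatenated. A quick check with illustrative dimensions:

```python
import torch
import torch.nn as nn

B, T, D, H = 4, 10, 320, 128
lstm = nn.LSTM(D, H, num_layers=1, batch_first=True, bidirectional=True)

output, (h_n, c_n) = lstm(torch.randn(B, T, D))

# Top layer: h_n[-2] is the forward direction, h_n[-1] the backward
final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
print(final_hidden.shape)  # torch.Size([4, 256])
```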

Classification Head

nn.Sequential(
    nn.Linear(lstm_output_dim, 64),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(64, num_classes)
)
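
Here lstm_output_dim is hidden_dim * (2 if bidirectional else 1). Checking the head in isolation (bidirectional case, default dimensions):

```python
import torch
import torch.nn as nn

hidden_dim, num_classes, dropout = 128, 2, 0.3
lstm_output_dim = hidden_dim * 2   # bidirectional

head = nn.Sequential(
    nn.Linear(lstm_output_dim, 64),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(64, num_classes),
)

final_hidden = torch.randn(4, lstm_output_dim)
logits = head(final_hidden)
print(logits.shape)  # torch.Size([4, 2])
```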

Example Usage

From main.py:257:
classifier = TemporalTabGNNClassifier(
    graph_emb_dim=512,  # GNN output is 256*2 from pooling
    tab_emb_dim=0,      # Graph-only mode
    hidden_dim=opt.lstm_hidden_dim,  # 64
    num_layers=opt.lstm_num_layers,  # 1
    dropout=0.45,
    bidirectional=opt.lstm_bidirectional,  # True
    num_classes=2
).to(device)

Training Example

for batch in train_loader:
    graph_seq = batch['graph_seq']  # [B, T, 512]
    lengths = batch['lengths']       # [B]
    labels = batch['labels']         # [B]

    optimizer.zero_grad()            # clear gradients from the previous step
    logits = classifier(graph_seq, None, lengths)
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()

Notes

  • Input dimension is graph_emb_dim + tab_emb_dim
  • LSTM output dimension is hidden_dim * (2 if bidirectional else 1)
  • Packed sequences improve efficiency for variable-length inputs
  • Set tab_emb_dim=0 for graph-only models
  • Dropout only applied between LSTM layers when num_layers > 1
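
Putting the pieces above together, a minimal sketch of the architecture described in this page (a simplified reimplementation under the documented behavior, not the project's actual source):

```python
import torch
import torch.nn as nn


class TemporalTabGNNClassifierSketch(nn.Module):
    """Sketch: fuse embeddings -> LSTM -> final hidden state -> MLP head."""

    def __init__(self, graph_emb_dim=256, tab_emb_dim=64, hidden_dim=128,
                 num_layers=1, dropout=0.3, bidirectional=False, num_classes=2):
        super().__init__()
        self.bidirectional = bidirectional
        self.lstm = nn.LSTM(
            graph_emb_dim + tab_emb_dim, hidden_dim,
            num_layers=num_layers, batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,  # LSTM dropout needs >1 layer
            bidirectional=bidirectional,
        )
        lstm_output_dim = hidden_dim * (2 if bidirectional else 1)
        self.head = nn.Sequential(
            nn.Linear(lstm_output_dim, 64), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(64, num_classes),
        )

    def forward(self, graph_seq, tab_seq=None, lengths=None, mask=None):
        # Input fusion: concatenate tabular features if present
        fused = graph_seq if tab_seq is None else torch.cat([graph_seq, tab_seq], dim=-1)
        if lengths is not None:
            fused = nn.utils.rnn.pack_padded_sequence(
                fused, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(fused)
        # Hidden state extraction: concatenate directions if bidirectional
        if self.bidirectional:
            final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
        else:
            final_hidden = h_n[-1]
        return self.head(final_hidden)


# Graph-only mode, matching the Example Usage above
model = TemporalTabGNNClassifierSketch(graph_emb_dim=512, tab_emb_dim=0)
logits = model(torch.randn(4, 10, 512), lengths=torch.tensor([10, 8, 6, 3]))
print(logits.shape)  # torch.Size([4, 2])
```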
