Overview

GRUPredictor is a GRU-based model for classifying temporal sequences of graph embeddings. It provides a lighter-weight alternative to an LSTM for temporal classification tasks.

Class Signature

class GRUPredictor(nn.Module):
    def __init__(
        self,
        graph_emb_dim: int = 256,
        tab_emb_dim: int = 64,
        hidden_dim: int = 128,
        num_layers: int = 1,
        dropout: float = 0.3,
        bidirectional: bool = False,
        num_classes: int = 2
    )

Parameters

graph_emb_dim
int
default:"256"
Dimension of graph embeddings from GNN encoder
tab_emb_dim
int
default:"64"
Dimension of tabular embeddings. Set to 0 for graph-only models
hidden_dim
int
default:"128"
GRU hidden state dimension
num_layers
int
default:"1"
Number of stacked GRU layers
dropout
float
default:"0.3"
Dropout probability (only applied if num_layers > 1)
bidirectional
bool
default:"False"
Whether to use bidirectional GRU
num_classes
int
default:"2"
Number of output classes

Forward Method

def forward(
    self,
    graph_seq: torch.Tensor,
    tab_seq: Optional[torch.Tensor] = None,
    lengths: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None
) -> torch.Tensor

Parameters

graph_seq
torch.Tensor
Graph embedding sequence of shape [batch_size, max_seq_len, graph_emb_dim]
tab_seq
torch.Tensor
default:"None"
Optional tabular embedding sequence of shape [batch_size, max_seq_len, tab_emb_dim]
lengths
torch.Tensor
default:"None"
True sequence lengths, shape [batch_size]. Used for packed sequences
mask
torch.Tensor
default:"None"
Attention mask of shape [batch_size, max_seq_len]

Returns

logits
torch.Tensor
Classification logits of shape [batch_size, num_classes]
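
The constructor and forward signatures above can be stitched into a runnable model. The sketch below is reconstructed purely from this page's documented behavior (input fusion, optional packing, hidden-state extraction, classification head); the actual GRUPredictor implementation may differ in details.

```python
# Minimal sketch of GRUPredictor reconstructed from the documented
# signatures; this is an illustration, not the actual implementation.
from typing import Optional

import torch
import torch.nn as nn


class GRUPredictor(nn.Module):
    def __init__(self, graph_emb_dim: int = 256, tab_emb_dim: int = 64,
                 hidden_dim: int = 128, num_layers: int = 1,
                 dropout: float = 0.3, bidirectional: bool = False,
                 num_classes: int = 2):
        super().__init__()
        self.tab_emb_dim = tab_emb_dim
        input_dim = graph_emb_dim + tab_emb_dim
        self.gru = nn.GRU(
            input_dim, hidden_dim, num_layers=num_layers,
            # nn.GRU only applies dropout between stacked layers
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=bidirectional, batch_first=True,
        )
        gru_output_dim = hidden_dim * (2 if bidirectional else 1)
        self.head = nn.Sequential(
            nn.Linear(gru_output_dim, 64), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(64, num_classes),
        )

    def forward(self, graph_seq: torch.Tensor,
                tab_seq: Optional[torch.Tensor] = None,
                lengths: Optional[torch.Tensor] = None,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # mask is accepted for API parity; unused in this sketch
        if tab_seq is None or self.tab_emb_dim == 0:
            fused = graph_seq
        else:
            fused = torch.cat([graph_seq, tab_seq], dim=-1)
        if lengths is not None:
            fused = nn.utils.rnn.pack_padded_sequence(
                fused, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, h_n = self.gru(fused)
        if self.gru.bidirectional:
            final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
        else:
            final_hidden = h_n[-1]
        return self.head(final_hidden)


# Shape check: batch of 4 graph-only sequences padded to 6 steps
model = GRUPredictor(graph_emb_dim=32, tab_emb_dim=0, hidden_dim=16)
graph_seq = torch.randn(4, 6, 32)
lengths = torch.tensor([6, 4, 3, 1])
logits = model(graph_seq, None, lengths)
print(logits.shape)  # torch.Size([4, 2])
```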

Architecture Details

Input Fusion

if tab_seq is None or self.tab_emb_dim == 0:
    fused = graph_seq
else:
    fused = torch.cat([graph_seq, tab_seq], dim=-1)
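
The fusion step simply widens the feature dimension, so the GRU input size is graph_emb_dim + tab_emb_dim. A quick shape check with the default dimensions:

```python
import torch

batch_size, max_seq_len = 2, 5
graph_seq = torch.randn(batch_size, max_seq_len, 256)  # graph_emb_dim = 256
tab_seq = torch.randn(batch_size, max_seq_len, 64)     # tab_emb_dim = 64

# Concatenate along the feature (last) dimension
fused = torch.cat([graph_seq, tab_seq], dim=-1)
print(fused.shape)  # torch.Size([2, 5, 320]); GRU input size is 256 + 64
```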

GRU Processing

With packed sequences:
fused_packed = nn.utils.rnn.pack_padded_sequence(
    fused, lengths.cpu(), batch_first=True, enforce_sorted=False
)
output_packed, h_n = self.gru(fused_packed)
Standard:
output, h_n = self.gru(fused)
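
Packing matters because with padded batches the final hidden state would otherwise be computed over padding steps. A small standalone check showing that the packed path returns the hidden state at each sequence's true length:

```python
import torch
import torch.nn as nn

gru = nn.GRU(8, 16, batch_first=True)
seq = torch.randn(3, 10, 8)          # padded to max_seq_len = 10
lengths = torch.tensor([10, 7, 2])   # true lengths per sequence

packed = nn.utils.rnn.pack_padded_sequence(
    seq, lengths.cpu(), batch_first=True, enforce_sorted=False)
_, h_n_packed = gru(packed)

# Without packing, h_n for sequence 1 would include 3 padded steps;
# with packing it equals the hidden state after step 7.
out, _ = gru(seq[1:2, :7])           # run sequence 1 up to its true length
assert torch.allclose(h_n_packed[-1, 1], out[0, -1], atol=1e-5)
```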

Hidden State Extraction

Bidirectional:
final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
Unidirectional:
final_hidden = h_n[-1]
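
The indexing works because nn.GRU returns h_n with shape [num_layers * num_directions, batch_size, hidden_dim], ordered layer by layer with directions adjacent, so the last two slices are the top layer's forward and backward states:

```python
import torch
import torch.nn as nn

gru = nn.GRU(8, 16, num_layers=2, bidirectional=True, batch_first=True)
x = torch.randn(4, 5, 8)
output, h_n = gru(x)

# h_n is [num_layers * num_directions, batch_size, hidden_dim]
print(h_n.shape)  # torch.Size([4, 4, 16])

# h_n[-2] / h_n[-1] are the top layer's forward / backward final states
final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
print(final_hidden.shape)  # torch.Size([4, 32])
```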

Classification Head

nn.Sequential(
    nn.Linear(gru_output_dim, 64),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(64, num_classes)
)
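
Here gru_output_dim is hidden_dim * (2 if bidirectional else 1), since bidirectional models concatenate both final states. A quick demonstration with the default unidirectional configuration:

```python
import torch
import torch.nn as nn

hidden_dim, bidirectional = 128, False
gru_output_dim = hidden_dim * (2 if bidirectional else 1)

head = nn.Sequential(
    nn.Linear(gru_output_dim, 64),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(64, 2),  # num_classes = 2
)

final_hidden = torch.randn(8, gru_output_dim)  # batch of 8 final states
logits = head(final_hidden)
print(logits.shape)  # torch.Size([8, 2])
```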

Example Usage

From main.py:267:
classifier = GRUPredictor(
    graph_emb_dim=512,
    tab_emb_dim=0,
    hidden_dim=opt.lstm_hidden_dim,  # 64
    num_layers=opt.lstm_num_layers,  # 1
    dropout=0.45,
    bidirectional=False,
    num_classes=2
).to(device)

Training Example

for batch in train_loader:
    graph_seq = batch['graph_seq']
    lengths = batch['lengths']
    labels = batch['labels']

    optimizer.zero_grad()  # clear gradients from the previous step
    logits = classifier(graph_seq, None, lengths)
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()

GRU vs LSTM

Advantages of GRU

  • Fewer parameters (no separate cell state)
  • Faster training and inference
  • Less prone to overfitting on small datasets
  • Simpler architecture

When to Use GRU

  • Smaller datasets with limited temporal patterns
  • When training speed is critical
  • When LSTM overfits
  • Shorter sequences (< 10 time steps)

Notes

  • GRU has no cell state, only hidden state (unlike LSTM)
  • GRU output dimension: hidden_dim * (2 if bidirectional else 1)
  • Compatible with packed sequences for variable-length inputs
  • Dropout only applied between layers when num_layers > 1
  • Typically trained with bidirectional=False in STGNN framework
