
Overview

RNNPredictor is a vanilla RNN-based model for temporal sequence processing. It provides the simplest recurrent architecture for baseline comparisons.

Class Signature

class RNNPredictor(nn.Module):
    def __init__(
        self,
        graph_emb_dim: int = 256,
        tab_emb_dim: int = 64,
        hidden_dim: int = 128,
        num_layers: int = 1,
        dropout: float = 0.3,
        bidirectional: bool = False,
        num_classes: int = 2
    )

Parameters

graph_emb_dim
int
default:"256"
Dimension of graph embeddings from GNN encoder
tab_emb_dim
int
default:"64"
Dimension of tabular embeddings. Set to 0 for graph-only models
hidden_dim
int
default:"128"
RNN hidden state dimension
num_layers
int
default:"1"
Number of stacked RNN layers
dropout
float
default:"0.3"
Dropout probability. Applied between stacked RNN layers only when num_layers > 1; also used in the classification head
bidirectional
bool
default:"False"
Whether to use bidirectional RNN
num_classes
int
default:"2"
Number of output classes

Forward Method

def forward(
    self,
    graph_seq: torch.Tensor,
    tab_seq: Optional[torch.Tensor] = None,
    lengths: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None
) -> torch.Tensor

Parameters

graph_seq
torch.Tensor
Graph embedding sequence of shape [batch_size, max_seq_len, graph_emb_dim]
tab_seq
torch.Tensor
default:"None"
Optional tabular embedding sequence of shape [batch_size, max_seq_len, tab_emb_dim]
lengths
torch.Tensor
default:"None"
True sequence lengths, shape [batch_size]. Used for packed sequences
mask
torch.Tensor
default:"None"
Padding mask of shape [batch_size, max_seq_len] indicating valid (non-padded) time steps

Returns

logits
torch.Tensor
Classification logits of shape [batch_size, num_classes]
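The forward contract above can be exercised end to end. The sketch below is a hypothetical minimal reimplementation of the described pipeline (fusion, RNN, classification head), not the framework's actual class; dimensions use the documented defaults:

```python
import torch
import torch.nn as nn

# Hypothetical minimal sketch of the documented pipeline; the real
# RNNPredictor may differ in details.
class TinyRNNPredictor(nn.Module):
    def __init__(self, graph_emb_dim=256, tab_emb_dim=64, hidden_dim=128,
                 num_layers=1, dropout=0.3, bidirectional=False, num_classes=2):
        super().__init__()
        self.tab_emb_dim = tab_emb_dim
        # Fused input dim assumes tab_seq is provided whenever tab_emb_dim > 0.
        self.input_dim = graph_emb_dim + tab_emb_dim
        self.rnn = nn.RNN(self.input_dim, hidden_dim, num_layers,
                          batch_first=True,
                          dropout=dropout if num_layers > 1 else 0,
                          bidirectional=bidirectional)
        out_dim = hidden_dim * (2 if bidirectional else 1)
        self.head = nn.Sequential(nn.Linear(out_dim, 64), nn.ReLU(),
                                  nn.Dropout(dropout), nn.Linear(64, num_classes))

    def forward(self, graph_seq, tab_seq=None, lengths=None, mask=None):
        fused = (graph_seq if tab_seq is None or self.tab_emb_dim == 0
                 else torch.cat([graph_seq, tab_seq], dim=-1))
        if lengths is not None:
            fused = nn.utils.rnn.pack_padded_sequence(
                fused, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, h_n = self.rnn(fused)
        if self.rnn.bidirectional:
            final = torch.cat([h_n[-2], h_n[-1]], dim=1)
        else:
            final = h_n[-1]
        return self.head(final)

model = TinyRNNPredictor()
graph_seq = torch.randn(4, 10, 256)  # [batch_size, max_seq_len, graph_emb_dim]
tab_seq = torch.randn(4, 10, 64)     # [batch_size, max_seq_len, tab_emb_dim]
logits = model(graph_seq, tab_seq)
print(logits.shape)                  # torch.Size([4, 2])
```

Note that the output is one logit vector per sequence, not per time step: only the final hidden state reaches the classification head.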

Architecture Details

RNN Configuration

self.rnn = nn.RNN(
    input_size=self.input_dim,
    hidden_size=hidden_dim,
    num_layers=num_layers,
    batch_first=True,
    dropout=dropout if num_layers > 1 else 0,
    bidirectional=bidirectional,
    nonlinearity='tanh'  # or 'relu'
)

Input Fusion

if tab_seq is None or self.tab_emb_dim == 0:
    fused = graph_seq
else:
    fused = torch.cat([graph_seq, tab_seq], dim=-1)

RNN Processing

With packed sequences (when lengths is provided):

fused_packed = nn.utils.rnn.pack_padded_sequence(
    fused, lengths.cpu(), batch_first=True, enforce_sorted=False
)
output_packed, h_n = self.rnn(fused_packed)

Standard (padded batch processed directly):

output, h_n = self.rnn(fused)
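Packing matters because it stops the RNN at each sequence's true end, so h_n is not contaminated by padded steps. A small standalone sketch (toy dimensions, not the framework's defaults) demonstrates this:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)

# Batch of 2 sequences padded to length 5; true lengths are 5 and 3.
x = torch.randn(2, 5, 8)
lengths = torch.tensor([5, 3])

packed = nn.utils.rnn.pack_padded_sequence(
    x, lengths.cpu(), batch_first=True, enforce_sorted=False)
_, h_n_packed = rnn(packed)

# Reference: run only the 3 valid steps of the second sequence.
# The packed h_n for that sequence matches its true final hidden state,
# whereas an unpacked run would process the two pad steps as well.
_, h_ref = rnn(x[1:2, :3])
print(torch.allclose(h_n_packed[0, 1], h_ref[0, 0]))  # True
```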

Hidden State Extraction

Bidirectional (concatenate the top layer's forward and backward states):

final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)

Unidirectional:

final_hidden = h_n[-1]
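The indexing above relies on h_n being laid out as [num_layers * num_directions, batch_size, hidden_dim], with the top layer's forward and backward states in the last two slices. A quick shape check with toy dimensions (not the framework's defaults):

```python
import torch
import torch.nn as nn

rnn = nn.RNN(input_size=8, hidden_size=16, num_layers=2,
             batch_first=True, bidirectional=True)
x = torch.randn(4, 10, 8)
_, h_n = rnn(x)

# h_n: [num_layers * num_directions, batch, hidden] = [2 * 2, 4, 16]
print(h_n.shape)  # torch.Size([4, 4, 16])

# Top layer's forward (h_n[-2]) and backward (h_n[-1]) states, concatenated.
final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
print(final_hidden.shape)  # torch.Size([4, 32])
```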

Classification Head

nn.Sequential(
    nn.Linear(rnn_output_dim, 64),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(64, num_classes)
)

Example Usage

From main.py:277:
classifier = RNNPredictor(
    graph_emb_dim=512,
    tab_emb_dim=0,
    hidden_dim=opt.lstm_hidden_dim,  # 64
    num_layers=opt.lstm_num_layers,  # 1
    dropout=0.45,
    bidirectional=False,
    num_classes=2
).to(device)

Training Example

for batch in train_loader:
    graph_seq = batch['graph_seq']
    lengths = batch['lengths']
    labels = batch['labels']

    optimizer.zero_grad()  # clear gradients from the previous step
    logits = classifier(graph_seq, None, lengths)
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()

Vanilla RNN Characteristics

Advantages

  • Simplest recurrent architecture
  • Fewest parameters
  • Good baseline for comparison
  • Fast training

Limitations

  • Suffers from vanishing gradients
  • Struggles with long-term dependencies
  • Less expressive than LSTM/GRU
  • Rarely outperforms gated variants

When to Use RNN

  • As a baseline to validate that temporal modeling helps
  • Very short sequences (< 5 time steps)
  • When extreme computational efficiency is required
  • When other models overfit

Comparison: RNN vs GRU vs LSTM

Model  Parameters  Memory  Long-term  Speed
RNN    Fewest      Low     Poor       Fastest
GRU    Medium      Medium  Good       Fast
LSTM   Most        High    Best       Slower
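The parameter ordering in the table follows directly from gating: for the same input and hidden sizes, a GRU holds 3x the recurrent weights of a vanilla RNN (3 gates) and an LSTM 4x (4 gates). A quick check, using this model's default dimensions (input_size=256 from graph_emb_dim, hidden_size=128):

```python
import torch.nn as nn

def count_params(m):
    return sum(p.numel() for p in m.parameters())

kw = dict(input_size=256, hidden_size=128, num_layers=1, batch_first=True)
n_rnn = count_params(nn.RNN(**kw))
n_gru = count_params(nn.GRU(**kw))
n_lstm = count_params(nn.LSTM(**kw))

# Every weight and bias tensor scales with the gate count,
# so the ratios are exact: GRU = 3x RNN, LSTM = 4x RNN.
print(n_rnn, n_gru, n_lstm)
assert n_rnn < n_gru < n_lstm
```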

Notes

  • Uses tanh nonlinearity by default (can be changed to relu)
  • No gating mechanisms (unlike GRU/LSTM)
  • RNN output dimension: hidden_dim * (2 if bidirectional else 1)
  • Typically trained with bidirectional=False in STGNN framework
  • Best suited for simple temporal patterns or as a baseline
