
Two-Stage Architecture

The STGNN system employs a two-stage architecture that separates spatial feature extraction from temporal sequence modeling.

Stage 1: Graph Neural Network Encoder

GraphNeuralNetwork Class (model.py:13-195)

The spatial encoder transforms brain connectivity graphs into fixed-dimensional embeddings.

Initialization Parameters

GraphNeuralNetwork(
    input_dim=100,           # Node feature dimension
    hidden_dim=128,          # Hidden layer dimension
    output_dim=256,          # Output dimension (before pooling)
    dropout=0.5,             # Dropout probability
    use_topk_pooling=True,   # Enable hierarchical pooling
    topk_ratio=0.5,          # Fraction of nodes to retain
    layer_type="GCN",        # GCN, GAT, or GraphSAGE
    num_layers=3,            # Number of graph convolution layers
    activation='relu',       # relu, leaky_relu, elu, gelu
    use_time_features=False  # Incorporate temporal gap features
)
Default Configuration: The model typically uses hidden_dim=256, output_dim=256, num_layers=2, layer_type="GraphSAGE", and activation="elu" in production settings.

Layer Architecture

The GNN dynamically constructs layers based on num_layers:
| Layer Position | Input Dimension | Output Dimension |
|---|---|---|
| First (i=0) | input_dim | hidden_dim |
| Middle (0 < i < n-1) | hidden_dim | hidden_dim |
| Last (i=n-1) | hidden_dim | output_dim |
Code reference: model.py:41-66
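The dimension schedule above can be sketched as a small helper (an illustration only; the actual construction lives in model.py:41-66):

```python
def layer_dims(input_dim, hidden_dim, output_dim, num_layers):
    """Input/output dimensions for each graph convolution layer."""
    dims = []
    for i in range(num_layers):
        in_d = input_dim if i == 0 else hidden_dim
        out_d = output_dim if i == num_layers - 1 else hidden_dim
        dims.append((in_d, out_d))
    return dims

# Production defaults: 2 layers with hidden_dim = output_dim = 256
print(layer_dims(100, 256, 256, 2))   # [(100, 256), (256, 256)]
print(layer_dims(100, 128, 256, 3))   # [(100, 128), (128, 128), (128, 256)]
```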

Forward Pass Modes

Standard Mode (_forward_traditional)

def _forward_traditional(self, x, edge_index, batch):
    for i in range(self.num_layers):
        x = self.convs[i](x, edge_index)      # Graph convolution
        x = self.gns[i](x, batch)             # Graph normalization
        x = self.activation_fn(x)             # Activation function
        x = self.dropout(x)                   # Dropout
    
    # Dual pooling for richer representation
    x_mean = global_mean_pool(x, batch)
    x_max = global_max_pool(x, batch)
    x = torch.cat([x_mean, x_max], dim=1)     # [B, output_dim*2]
    return x
Code reference: model.py:157-171
Output Dimension: Due to mean+max pooling concatenation, the actual output dimension is output_dim * 2 (default: 512D from 256×2).
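A minimal sketch of why the dimension doubles, using plain tensor ops in place of torch_geometric's global_mean_pool/global_max_pool:

```python
import torch

# Toy batch: 5 nodes across 2 graphs, 4 features per node
x = torch.randn(5, 4)
batch = torch.tensor([0, 0, 0, 1, 1])   # graph assignment per node

# Per-graph mean and max pooling (equivalent to the PyG helpers here)
x_mean = torch.stack([x[batch == g].mean(dim=0) for g in range(2)])
x_max = torch.stack([x[batch == g].max(dim=0).values for g in range(2)])
pooled = torch.cat([x_mean, x_max], dim=1)

print(pooled.shape)   # torch.Size([2, 8]) -- feature dimension doubles
```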

TopK Pooling Mode (_forward_with_topk)

Hierarchical graph coarsening is applied after each convolution layer:
def _forward_with_topk(self, x, edge_index, batch):
    for i in range(self.num_layers):
        x = self.convs[i](x, edge_index)      # Convolution
        x = self.gns[i](x, batch)             # Normalization
        x = self.activation_fn(x)             # Activation
        x = self.dropout(x)                   # Dropout
        
        # Hierarchical pooling: retain top nodes by importance
        x, edge_index, _, batch, perm, score = self.topk_pools[i](
            x, edge_index, batch=batch
        )
    
    # Global pooling on remaining nodes
    x_mean = global_mean_pool(x, batch)
    x_max = global_max_pool(x, batch)
    x = torch.cat([x_mean, x_max], dim=1)
    return x
Code reference: model.py:136-155
Minimum Retention: To prevent empty graphs, TopK pooling enforces a minimum 30% node retention rate regardless of the configured topk_ratio (model.py:74).
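The effect on graph size can be sketched as follows (`nodes_after_pooling` is a hypothetical helper; TopKPooling retains ceil(ratio × N) nodes per layer, with the ratio floored at 0.3 as described above):

```python
import math

def nodes_after_pooling(num_nodes, topk_ratio, num_layers, min_ratio=0.3):
    """Node counts after each pooling layer (hypothetical helper)."""
    ratio = max(topk_ratio, min_ratio)   # minimum-retention floor (model.py:74)
    counts = [num_nodes]
    for _ in range(num_layers):
        counts.append(math.ceil(ratio * counts[-1]))
    return counts

print(nodes_after_pooling(100, 0.5, 3))  # [100, 50, 25, 13]
print(nodes_after_pooling(100, 0.1, 1))  # [100, 30] -- floor kicks in
```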

Time-Aware Extension

When use_time_features=True, temporal gap information is fused with spatial embeddings:
if self.use_time_features and time_to_predict is not None:
    # Project time feature to small dimension (32D)
    time_embedding = self.time_projection(time_to_predict)  # [B, 32]
    
    # Concatenate graph (512D) + time (32D)
    combined = torch.cat([graph_embedding, time_embedding], dim=1)  # [B, 544]
    
    # Fuse to final dimension
    final_embedding = self.fusion_layer(combined)  # [B, 512]
Code reference: model.py:120-132
Time Feature Design: Time features use a small 32D projection to prevent temporal information from dominating the 512D graph features, maintaining focus on connectivity patterns.
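The fusion path can be exercised with stand-in modules. This is a sketch: `time_projection` and `fusion_layer` mirror the attribute names used above, but the exact layer definitions here are assumptions.

```python
import torch
import torch.nn as nn

time_projection = nn.Linear(1, 32)        # time gap -> 32D (layer type assumed)
fusion_layer = nn.Linear(512 + 32, 512)   # fuse back down to 512D

graph_embedding = torch.randn(8, 512)     # dual-pooled GNN output, batch of 8
time_to_predict = torch.randn(8, 1)       # temporal gap per sample

time_embedding = time_projection(time_to_predict)               # [8, 32]
combined = torch.cat([graph_embedding, time_embedding], dim=1)  # [8, 544]
final_embedding = fusion_layer(combined)                        # [8, 512]
print(final_embedding.shape)
```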

Stage 2: Temporal Sequence Modeling

TemporalTabGNNClassifier (TemporalPredictor.py:5-86)

The LSTM-based temporal predictor processes sequences of graph embeddings.

Architecture Components

class TemporalTabGNNClassifier(nn.Module):
    def __init__(
        self,
        graph_emb_dim: int = 256,      # Must match GNN pooled output (output_dim × 2, i.e. 512 in production)
        tab_emb_dim: int = 64,         # Optional tabular features
        hidden_dim: int = 128,         # LSTM hidden dimension
        num_layers: int = 1,           # LSTM depth
        dropout: float = 0.3,
        bidirectional: bool = False,
        num_classes: int = 2
    )
self.lstm = nn.LSTM(
    input_size=graph_emb_dim + tab_emb_dim,  # 512 + 64 = 576D
    hidden_size=hidden_dim,
    num_layers=num_layers,
    batch_first=True,
    dropout=dropout if num_layers > 1 else 0,
    bidirectional=bidirectional
)
Code reference: TemporalPredictor.py:24-29

Forward Pass Details

def forward(self, graph_seq, tab_seq=None, lengths=None, mask=None):
    # Step 1: Concatenate modalities
    if tab_seq is None or self.tab_emb_dim == 0:
        fused = graph_seq  # [B, T, 512]
    else:
        fused = torch.cat([graph_seq, tab_seq], dim=-1)  # [B, T, 576]
    
    # Step 2: Process with LSTM (with optional packed sequences)
    if lengths is not None:
        fused_packed = nn.utils.rnn.pack_padded_sequence(
            fused, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        output_packed, (h_n, c_n) = self.lstm(fused_packed)
    else:
        output, (h_n, c_n) = self.lstm(fused)
    
    # Step 3: Extract final hidden state
    if self.bidirectional:
        final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
    else:
        final_hidden = h_n[-1]
    
    # Step 4: Classify
    logits = self.classifier(final_hidden)  # [B, 2]
    return logits
Code reference: TemporalPredictor.py:40-86
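The packed-sequence path can be exercised end to end with stand-in tensors (a sketch using the documented defaults; the real classifier head lives in TemporalPredictor.py):

```python
import torch
import torch.nn as nn

B, T, D = 4, 6, 512                      # batch, max visits, graph_emb_dim
graph_seq = torch.randn(B, T, D)         # padded embedding sequences
lengths = torch.tensor([6, 4, 3, 2])     # true sequence lengths

lstm = nn.LSTM(input_size=D, hidden_size=128, batch_first=True)
packed = nn.utils.rnn.pack_padded_sequence(
    graph_seq, lengths, batch_first=True, enforce_sorted=False
)
_, (h_n, _) = lstm(packed)               # h_n: [num_layers, B, 128]

final_hidden = h_n[-1]                   # last layer's final hidden state
logits = nn.Linear(128, 2)(final_hidden) # binary classification head
print(final_hidden.shape, logits.shape)
```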

Alternative Temporal Architectures

GRU Predictor

File: GRUPredictor.py. Uses Gated Recurrent Units instead of an LSTM. A simpler architecture that maintains only a hidden state (no cell state), which can train faster.
self.gru = nn.GRU(
    input_size=self.input_dim,
    hidden_size=hidden_dim,
    num_layers=num_layers,
    batch_first=True,
    dropout=dropout if num_layers > 1 else 0,
    bidirectional=bidirectional
)
Code reference: GRUPredictor.py:23-28

RNN Predictor

File: RNNPredictor.py. A vanilla RNN with a tanh nonlinearity. The simplest architecture, useful for baseline comparisons.
self.rnn = nn.RNN(
    input_size=self.input_dim,
    hidden_size=hidden_dim,
    num_layers=num_layers,
    batch_first=True,
    dropout=dropout if num_layers > 1 else 0,
    bidirectional=bidirectional,
    nonlinearity='tanh'
)
Code reference: RNNPredictor.py:23-29
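The "simpler architecture" claims can be quantified by parameter count at matching sizes (illustrative only; sizes follow the documented 512D input):

```python
import torch.nn as nn

def n_params(module):
    return sum(p.numel() for p in module.parameters())

# One layer each, 512D input, 128D hidden
lstm = nn.LSTM(512, 128, batch_first=True)
gru = nn.GRU(512, 128, batch_first=True)
rnn = nn.RNN(512, 128, batch_first=True)

# LSTM uses 4 gates, GRU 3, vanilla RNN 1 -- parameters scale accordingly
print(n_params(lstm), n_params(gru), n_params(rnn))  # 328704 246528 82176
```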

Data Pipeline Integration

TemporalDataLoader (TemporalDataLoader.py)

The custom data loader bridges the two architecture stages:
1. Group by Subject

Brain scans are grouped by patient ID to create temporal sequences (_group_by_subject()).

2. Batch Construction

Multiple subjects are batched together, and all of their graphs across all visits are collected.

3. Batched Encoding

All graphs in the batch are processed in a single forward pass through the GNN encoder for efficiency.
big_batch = Batch.from_data_list(all_data).to(device)
all_embeddings = encoder(big_batch.x, big_batch.edge_index,
                         big_batch.batch, time_features_tensor)
Code reference: TemporalDataLoader.py:194-212

4. Sequence Reshaping

Embeddings are reshaped back into per-subject sequences based on visit ordering.

5. Padding

Variable-length sequences are padded to the same length for batch processing.
graph_seq, lengths, labels = self._pad_sequences(batch_sequences, subject_labels)
Code reference: TemporalDataLoader.py:226
Efficiency Gain: Batched encoding processes hundreds of graphs simultaneously rather than sequentially, dramatically reducing computation time.
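The padding step can be sketched with torch's built-in helper (_pad_sequences is the loader's own method; pad_sequence is shown here as a stand-in for the graph tensor):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three subjects with 5, 3, and 2 visits; each visit is a 512D embedding
seqs = [torch.randn(t, 512) for t in (5, 3, 2)]
lengths = torch.tensor([s.size(0) for s in seqs])

graph_seq = pad_sequence(seqs, batch_first=True)  # zero-pad to max length
print(graph_seq.shape)   # torch.Size([3, 5, 512])
```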

Training Configuration

Model Selection (main.py:38-52)

parser.add_argument('--model_type', type=str, default='LSTM')
parser.add_argument('--layer_type', type=str, default="GraphSAGE")
parser.add_argument('--gnn_hidden_dim', type=int, default=256)
parser.add_argument('--gnn_num_layers', type=int, default=2)
parser.add_argument('--gnn_activation', type=str, default='elu')
parser.add_argument('--lstm_hidden_dim', type=int, default=64)
parser.add_argument('--lstm_num_layers', type=int, default=1)
parser.add_argument('--lstm_bidirectional', type=bool, default=True)
parser.add_argument('--freeze_encoder', action='store_true')
parser.add_argument('--use_topk_pooling', action='store_true', default=True)
parser.add_argument('--topk_ratio', type=float, default=0.3)
parser.add_argument('--use_time_features', action='store_true')

Encoder Freezing

When --freeze_encoder is set, the GNN weights are frozen during temporal training:
if opt.freeze_encoder:
    for param in encoder.parameters():
        param.requires_grad = False
This is useful when:
  • The encoder is pre-trained on large datasets
  • You want to prevent catastrophic forgetting
  • Temporal training data is limited
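With the encoder frozen, only the temporal stage's parameters should reach the optimizer. A minimal sketch with stand-in modules:

```python
import torch
import torch.nn as nn

encoder = nn.Linear(100, 512)    # stand-in for the GNN encoder
temporal = nn.Linear(512, 2)     # stand-in for the temporal classifier

for param in encoder.parameters():   # --freeze_encoder behaviour
    param.requires_grad = False

all_params = list(encoder.parameters()) + list(temporal.parameters())
trainable = [p for p in all_params if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)

print(len(trainable))   # 2: the temporal layer's weight and bias
```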

Next Steps

GNN Details

Layer types and graph operations

Temporal Modeling

RNN architecture comparisons

Spatiotemporal Integration

How spatial and temporal stages combine
