Two-Stage Architecture
The STGNN system employs a two-stage architecture that separates spatial feature extraction from temporal sequence modeling.
Stage 1: Graph Neural Network Encoder
GraphNeuralNetwork Class (model.py:13-195)
The spatial encoder transforms brain connectivity graphs into fixed-dimensional embeddings.
Initialization Parameters
```python
GraphNeuralNetwork(
    input_dim=100,            # Node feature dimension
    hidden_dim=128,           # Hidden layer dimension
    output_dim=256,           # Output dimension (before pooling)
    dropout=0.5,              # Dropout probability
    use_topk_pooling=True,    # Enable hierarchical pooling
    topk_ratio=0.5,           # Fraction of nodes to retain
    layer_type="GCN",         # GCN, GAT, or GraphSAGE
    num_layers=3,             # Number of graph convolution layers
    activation='relu',        # relu, leaky_relu, elu, gelu
    use_time_features=False   # Incorporate temporal gap features
)
```
Default Configuration: The model typically uses hidden_dim=256, output_dim=256, num_layers=2, layer_type="GraphSAGE", and activation="elu" in production settings.
Layer Architecture
The GNN dynamically constructs layers based on num_layers:
| Layer Position | Input Dimension | Output Dimension |
|---|---|---|
| First (i=0) | input_dim | hidden_dim |
| Middle (0 < i < n-1) | hidden_dim | hidden_dim |
| Last (i=n-1) | hidden_dim | output_dim |

Code reference: model.py:41-66
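The dimension schedule above can be sketched as a small helper (`layer_dims` is an illustrative name; the real construction lives in model.py:41-66):

```python
def layer_dims(input_dim, hidden_dim, output_dim, num_layers):
    """Per-layer (input, output) dimensions following the table above."""
    dims = []
    for i in range(num_layers):
        d_in = input_dim if i == 0 else hidden_dim                  # first layer consumes node features
        d_out = output_dim if i == num_layers - 1 else hidden_dim   # last layer emits output_dim
        dims.append((d_in, d_out))
    return dims
```

Note the single-layer edge case: with num_layers=1 the first and last rules coincide, giving one input_dim → output_dim layer.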
Forward Pass Modes
Standard Mode (_forward_traditional)
```python
def _forward_traditional(self, x, edge_index, batch):
    for i in range(self.num_layers):
        x = self.convs[i](x, edge_index)   # Graph convolution
        x = self.gns[i](x, batch)          # Graph normalization
        x = self.activation_fn(x)          # Activation function
        x = self.dropout(x)                # Dropout
    # Dual pooling for richer representation
    x_mean = global_mean_pool(x, batch)
    x_max = global_max_pool(x, batch)
    x = torch.cat([x_mean, x_max], dim=1)  # [B, output_dim*2]
    return x
```

Code reference: model.py:157-171

Output Dimension: Due to mean+max pooling concatenation, the actual output dimension is output_dim * 2 (default: 512D from 256×2).
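The doubling comes from concatenating two pooled views of the same node features. A dependency-free sketch of the mean+max pooling (plain lists standing in for tensors; `dual_pool` is an illustrative name, not the library call):

```python
def dual_pool(node_feats, batch):
    """Mean+max pool node features per graph, then concatenate.

    node_feats: one feature vector (list of floats) per node.
    batch: graph index for each node (as in PyG batching).
    Returns one vector of length 2 * feat_dim per graph.
    """
    graphs = {}
    for feats, g in zip(node_feats, batch):
        graphs.setdefault(g, []).append(feats)
    pooled = []
    for g in sorted(graphs):
        cols = list(zip(*graphs[g]))                 # transpose to per-feature columns
        mean = [sum(c) / len(c) for c in cols]       # mean over nodes
        mx = [max(c) for c in cols]                  # max over nodes
        pooled.append(mean + mx)                     # concat doubles the width
    return pooled
```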
TopK Pooling Mode (_forward_with_topk)
Hierarchical graph coarsening is applied after each convolution layer:
```python
def _forward_with_topk(self, x, edge_index, batch):
    for i in range(self.num_layers):
        x = self.convs[i](x, edge_index)   # Convolution
        x = self.gns[i](x, batch)          # Normalization
        x = self.activation_fn(x)          # Activation
        x = self.dropout(x)                # Dropout
        # Hierarchical pooling: retain top nodes by importance
        x, edge_index, _, batch, perm, score = self.topk_pools[i](
            x, edge_index, batch=batch
        )
    # Global pooling on remaining nodes
    x_mean = global_mean_pool(x, batch)
    x_max = global_max_pool(x, batch)
    x = torch.cat([x_mean, x_max], dim=1)
    return x
```

Code reference: model.py:136-155

Minimum Retention: To prevent empty graphs, TopK pooling enforces a minimum 30% node retention rate regardless of the configured topk_ratio (model.py:74).
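Assuming the floor is applied as a simple clamp (consistent with the 30% figure above), the effective ratio can be sketched as:

```python
MIN_RETENTION = 0.3  # floor noted above (model.py:74)

def effective_topk_ratio(topk_ratio):
    """Fraction of nodes each TopK pooling layer actually retains."""
    return max(topk_ratio, MIN_RETENTION)
```

So the default topk_ratio=0.3 from main.py sits exactly at the floor, while larger ratios pass through unchanged.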
Time-Aware Extension
When use_time_features=True, temporal gap information is fused with spatial embeddings:
```python
if self.use_time_features and time_to_predict is not None:
    # Project time feature to a small dimension (32D)
    time_embedding = self.time_projection(time_to_predict)  # [B, 32]
    # Concatenate graph (512D) + time (32D)
    combined = torch.cat([graph_embedding, time_embedding], dim=1)  # [B, 544]
    # Fuse to final dimension
    final_embedding = self.fusion_layer(combined)  # [B, 512]
```

Code reference: model.py:120-132

Time Feature Design: Time features use a small 32D projection to prevent temporal information from dominating the 512D graph features, maintaining focus on connectivity patterns.
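The dimension arithmetic of the fusion step, traced explicitly (`time_fusion_dims` is an illustrative helper, not part of the codebase):

```python
def time_fusion_dims(graph_dim=512, time_dim=32):
    """Trace the tensor widths through the time-aware fusion above."""
    combined_dim = graph_dim + time_dim  # after torch.cat: 544
    fused_dim = graph_dim                # fusion_layer maps back to 512
    return combined_dim, fused_dim
```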
Stage 2: Temporal Sequence Modeling
TemporalTabGNNClassifier (TemporalPredictor.py:5-86)
The LSTM-based temporal predictor processes sequences of graph embeddings.
Architecture Components
```python
class TemporalTabGNNClassifier(nn.Module):
    def __init__(
        self,
        graph_emb_dim: int = 256,   # Must match GNN output (×2)
        tab_emb_dim: int = 64,      # Optional tabular features
        hidden_dim: int = 128,      # LSTM hidden dimension
        num_layers: int = 1,        # LSTM depth
        dropout: float = 0.3,
        bidirectional: bool = False,
        num_classes: int = 2
    )
```
LSTM Layer

```python
self.lstm = nn.LSTM(
    input_size=graph_emb_dim + tab_emb_dim,  # 512 + 64 = 576D
    hidden_size=hidden_dim,
    num_layers=num_layers,
    batch_first=True,
    dropout=dropout if num_layers > 1 else 0,
    bidirectional=bidirectional
)
lstm_output_dim = hidden_dim * (2 if bidirectional else 1)
```

Code reference: TemporalPredictor.py:24-29

Classifier MLP

```python
self.classifier = nn.Sequential(
    nn.Linear(lstm_output_dim, 64),
    nn.ReLU(),
    nn.Dropout(dropout),
    nn.Linear(64, num_classes)
)
```

Code reference: TemporalPredictor.py:31-38
Forward Pass Details
```python
def forward(self, graph_seq, tab_seq=None, lengths=None, mask=None):
    # Step 1: Concatenate modalities
    if tab_seq is None or self.tab_emb_dim == 0:
        fused = graph_seq  # [B, T, 512]
    else:
        fused = torch.cat([graph_seq, tab_seq], dim=-1)  # [B, T, 576]

    # Step 2: Process with LSTM (with optional packed sequences)
    if lengths is not None:
        fused_packed = nn.utils.rnn.pack_padded_sequence(
            fused, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        output_packed, (h_n, c_n) = self.lstm(fused_packed)
    else:
        output, (h_n, c_n) = self.lstm(fused)

    # Step 3: Extract final hidden state
    if self.bidirectional:
        final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
    else:
        final_hidden = h_n[-1]

    # Step 4: Classify
    logits = self.classifier(final_hidden)  # [B, 2]
    return logits
```

Code reference: TemporalPredictor.py:40-86
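Tracing the expected shapes through that forward pass (a dependency-free sketch; `forward_shapes` is an illustrative helper, with 512D being the graph embedding after the GNN's dual pooling):

```python
def forward_shapes(batch, seq_len, graph_emb_dim=512, tab_emb_dim=64,
                   hidden_dim=128, bidirectional=False, num_classes=2):
    """Expected tensor shapes at each step of the forward pass above."""
    fused = (batch, seq_len, graph_emb_dim + tab_emb_dim)  # Step 1: concat
    lstm_out = hidden_dim * (2 if bidirectional else 1)    # both directions concatenated
    final_hidden = (batch, lstm_out)                       # Step 3: last hidden state
    logits = (batch, num_classes)                          # Step 4: classifier output
    return fused, final_hidden, logits
```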
Alternative Temporal Architectures
GRU Predictor

File: GRUPredictor.py

Uses Gated Recurrent Units instead of LSTM. Simpler architecture with only a hidden state (no cell state), potentially faster training.

```python
self.gru = nn.GRU(
    input_size=self.input_dim,
    hidden_size=hidden_dim,
    num_layers=num_layers,
    batch_first=True,
    dropout=dropout if num_layers > 1 else 0,
    bidirectional=bidirectional
)
```

Code reference: GRUPredictor.py:23-28
RNN Predictor

File: RNNPredictor.py

Vanilla RNN with tanh nonlinearity. The simplest architecture, useful for baseline comparisons.

```python
self.rnn = nn.RNN(
    input_size=self.input_dim,
    hidden_size=hidden_dim,
    num_layers=num_layers,
    batch_first=True,
    dropout=dropout if num_layers > 1 else 0,
    bidirectional=bidirectional,
    nonlinearity='tanh'
)
```

Code reference: RNNPredictor.py:23-29
Data Pipeline Integration
TemporalDataLoader (TemporalDataLoader.py)
The custom data loader bridges the two architecture stages:
Group by Subject
Brain scans are grouped by patient ID to create temporal sequences (_group_by_subject())
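A sketch of that grouping step (the record layout and field names here are assumptions; the real logic lives in _group_by_subject()):

```python
from collections import defaultdict

def group_by_subject(scans):
    """Group (subject_id, visit_index, graph) records into visit-ordered sequences."""
    groups = defaultdict(list)
    for subject_id, visit, graph in scans:
        groups[subject_id].append((visit, graph))
    # Sort each subject's scans by visit index to preserve temporal order
    return {s: [g for _, g in sorted(pairs)] for s, pairs in groups.items()}
```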
Batch Construction
Multiple subjects are batched together. All their graphs across all visits are collected.
Batched Encoding

All graphs in the batch are processed in a single forward pass through the GNN encoder for efficiency.

```python
big_batch = Batch.from_data_list(all_data).to(device)
all_embeddings = encoder(big_batch.x, big_batch.edge_index,
                         big_batch.batch, time_features_tensor)
```

Code reference: TemporalDataLoader.py:194-212
Sequence Reshaping
Embeddings are reshaped back into per-subject sequences based on visit ordering.
Padding

Variable-length sequences are padded to the same length for batch processing.

```python
graph_seq, lengths, labels = self._pad_sequences(batch_sequences, subject_labels)
```

Code reference: TemporalDataLoader.py:226

Efficiency Gain: Batched encoding processes hundreds of graphs simultaneously rather than sequentially, dramatically reducing computation time.
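The padding step can be sketched without torch (plain nested lists standing in for tensors; `pad_sequences` is an illustrative stand-in for _pad_sequences):

```python
def pad_sequences(sequences, pad_value=0.0):
    """Right-pad variable-length embedding sequences to a common length.

    sequences: per-subject lists of embedding vectors.
    Returns (padded, lengths): padded has shape [B, T_max, D];
    lengths records each subject's true visit count (used by the
    packed-sequence path in the LSTM forward pass).
    """
    lengths = [len(s) for s in sequences]
    t_max = max(lengths)
    dim = len(sequences[0][0])
    padded = [s + [[pad_value] * dim for _ in range(t_max - len(s))]
              for s in sequences]
    return padded, lengths
```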
Training Configuration
Model Selection (main.py:38-52)
```python
parser.add_argument('--model_type', type=str, default='LSTM')
parser.add_argument('--layer_type', type=str, default="GraphSAGE")
parser.add_argument('--gnn_hidden_dim', type=int, default=256)
parser.add_argument('--gnn_num_layers', type=int, default=2)
parser.add_argument('--gnn_activation', type=str, default='elu')
parser.add_argument('--lstm_hidden_dim', type=int, default=64)
parser.add_argument('--lstm_num_layers', type=int, default=1)
parser.add_argument('--lstm_bidirectional', type=bool, default=True)
parser.add_argument('--freeze_encoder', action='store_true')
parser.add_argument('--use_topk_pooling', action='store_true', default=True)
parser.add_argument('--topk_ratio', type=float, default=0.3)
parser.add_argument('--use_time_features', action='store_true')
```
Encoder Freezing
When --freeze_encoder is set, the GNN weights are frozen during temporal training:
```python
if opt.freeze_encoder:
    for param in encoder.parameters():
        param.requires_grad = False
```
This is useful when:
- The encoder is pre-trained on large datasets
- You want to prevent catastrophic forgetting
- Temporal training data is limited
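A common companion to freezing is handing the optimizer only the still-trainable parameters. A dependency-free sketch (stand-in objects instead of torch tensors; filtering by requires_grad works the same way when building a torch.optim optimizer):

```python
from types import SimpleNamespace

def trainable(params):
    """Parameters the optimizer should still update after freezing."""
    return [p for p in params if p.requires_grad]

# Stand-ins for encoder (frozen) and classifier (trainable) parameters:
encoder_w = SimpleNamespace(requires_grad=False)
classifier_w = SimpleNamespace(requires_grad=True)
```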
Next Steps
- GNN Details: Layer types and graph operations
- Temporal Modeling: RNN architecture comparisons
- Spatiotemporal Integration: How spatial and temporal stages combine