Architecture Overview
The STGNN model combines two neural network components:
- Graph Encoder (GraphNeuralNetwork): Extracts spatial features from brain connectivity graphs
- Temporal Predictor (LSTM/GRU/RNN): Models temporal progression patterns across visits
Graph Neural Network Encoder
Dynamic Layer Architecture
The GNN encoder dynamically constructs layers based on the num_layers parameter (model.py:40-66):

```python
for i in range(num_layers):
    if i == 0:
        in_dim, out_dim = input_dim, hidden_dim   # 100 → 256
    elif i == num_layers - 1:
        in_dim, out_dim = hidden_dim, output_dim  # 256 → 256
    else:
        in_dim, out_dim = hidden_dim, hidden_dim  # 256 → 256
```
Default Configuration (2 layers):
- Layer 1: 100 → 256 (input features to hidden)
- Layer 2: 256 → 256 (hidden to output)
- Final pooling: 256 × 2 = 512D (mean + max concatenation)
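The dimension logic can be traced without instantiating the model; a minimal sketch (the helper name layer_dims is hypothetical, defaults taken from the configuration above):

```python
# Sketch: resolve the per-layer (in_dim, out_dim) pairs exactly as the
# constructor loop does. Defaults mirror the documented configuration:
# input_dim=100, hidden_dim=256, output_dim=256.
def layer_dims(num_layers, input_dim=100, hidden_dim=256, output_dim=256):
    dims = []
    for i in range(num_layers):
        if i == 0:
            in_dim, out_dim = input_dim, hidden_dim
        elif i == num_layers - 1:
            in_dim, out_dim = hidden_dim, output_dim
        else:
            in_dim, out_dim = hidden_dim, hidden_dim
        dims.append((in_dim, out_dim))
    return dims

# Default 2-layer setup → [(100, 256), (256, 256)]
```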
Layer Types
Three layer types are available: GraphSAGE (default), GCN, and GAT.
GraphSAGE Architecture
Aggregation Strategy: Sampling and aggregating neighborhood information

```python
conv = SAGEConv(in_dim, out_dim)
```
Properties:
- Inductive learning (generalizes to unseen nodes)
- Mean aggregation of neighbor features
- Efficient for large graphs
- Best for variable graph structures
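A toy illustration of the mean-aggregation step (hypothetical values; the real SAGEConv additionally applies learned linear maps to both parts):

```python
# Toy sketch of GraphSAGE-style neighbourhood aggregation for a single node.
def sage_aggregate(h_self, h_neighbors):
    # mean of the neighbours' feature vectors, one value per feature
    n = len(h_neighbors)
    mean = [sum(col) / n for col in zip(*h_neighbors)]
    # SAGEConv would now apply separate learned linear maps to h_self and mean
    return h_self, mean

own, agg = sage_aggregate(h_self=[1.0, 2.0],
                          h_neighbors=[[2.0, 0.0], [4.0, 2.0]])
# agg → [3.0, 1.0]
```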
Use Case: Default choice for brain connectivity graphs with varying node importance

Reference: model.py:60-61

Graph Convolutional Network
Convolution Strategy: Spectral graph convolutions via normalized adjacency

```python
conv = GCNConv(in_dim, out_dim)
```
Properties:
- Symmetric normalization: Ã = D^(-1/2) A D^(-1/2)
- Equal importance to all edges
- Transductive learning (requires full graph)
- Faster computation than GAT
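The symmetric normalization can be verified on a toy graph; a minimal sketch assuming self-loops are added, as GCNConv does by default:

```python
# Sketch: symmetric normalization Ã = D^(-1/2) (A + I) D^(-1/2)
# on a toy 3-node path graph, with self-loops added first.
A = [[0, 1, 0],
     [1, 0, 1],
     [0, 1, 0]]
n = len(A)

# add self-loops: A + I
A_hat = [[A[i][j] + (1 if i == j else 0) for j in range(n)] for i in range(n)]

# degree of each node in the self-looped graph
deg = [sum(row) for row in A_hat]

# normalize each entry by 1 / (sqrt(d_i) * sqrt(d_j))
A_tilde = [[A_hat[i][j] / (deg[i] ** 0.5 * deg[j] ** 0.5) for j in range(n)]
           for i in range(n)]
```

The result stays symmetric, so every edge contributes with equal (degree-scaled) weight in both directions, which is why GCN treats all connections uniformly.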
Use Case: When all brain connections should be treated uniformly

Formula: H^(l+1) = σ(ÃH^(l)W^(l))

Reference: model.py:56-57

Graph Attention Network
Attention Strategy: Learns importance weights for each connection

```python
conv = GATConv(in_dim, out_dim)
```
Properties:
- Attention mechanism over neighbors
- Learns edge importance dynamically
- Multi-head attention support
- Highest model capacity
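A minimal numeric sketch of the attention computation (toy weights, single head; the real GATConv learns the vector a and the matrix W during training):

```python
import math

# Sketch: GAT attention for one node i over its neighbours j:
# e_ij = LeakyReLU(a^T [W h_i || W h_j]),  α_ij = softmax over j of e_ij.
def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def attention(a, Wh_i, Wh_js):
    # raw scores: dot product of a with the concatenated pair, then LeakyReLU
    e = [leaky_relu(sum(ai * x for ai, x in zip(a, Wh_i + Wh_j)))
         for Wh_j in Wh_js]
    # numerically stable softmax over the neighbours
    m = max(e)
    exp_e = [math.exp(v - m) for v in e]
    s = sum(exp_e)
    return [v / s for v in exp_e]

alphas = attention(a=[0.5, -0.5, 1.0, 0.5],
                   Wh_i=[1.0, 0.0],
                   Wh_js=[[1.0, 1.0], [0.0, 0.0]])
# the two coefficients sum to 1; the higher-scoring neighbour gets more weight
```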
Use Case: When different brain connections have varying importance

Formula: α_ij = softmax(LeakyReLU(a^T [Wh_i || Wh_j]))

Note: Higher memory usage and slower training

Reference: model.py:58-59
Activation Functions
Configurable via --gnn_activation (model.py:27-34):
ELU (Default)

Formula: f(x) = x if x > 0 else α(e^x - 1)

Properties:
- Smooth for negative values
- Faster convergence than ReLU
- Mean activation closer to zero

Default α: 1.0

ReLU

Formula: f(x) = max(0, x)

Properties:
- Standard activation
- Fast computation
- Can suffer from dead neurons

Leaky ReLU

```python
activation_fn = F.leaky_relu
```

Formula: f(x) = x if x > 0 else αx

Properties:
- Prevents dead neurons
- Small gradient for negatives

Default α: 0.01

GELU

Formula: f(x) = x · Φ(x)

Properties:
- Smooth, probabilistic
- Used in transformers
- Better for deep networks
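The four activations can be compared side by side; a minimal sketch using their textbook definitions (GELU via the exact erf form of Φ):

```python
import math

# Sketch: the four configurable activations evaluated directly.
# ELU uses α=1.0 and Leaky ReLU α=0.01, matching the documented defaults.
def elu(x, a=1.0):
    return x if x > 0 else a * (math.exp(x) - 1)

def relu(x):
    return max(0.0, x)

def leaky(x, a=0.01):
    return x if x > 0 else a * x

def gelu(x):
    # Φ(x) is the standard normal CDF: 0.5 * (1 + erf(x / sqrt(2)))
    return x * 0.5 * (1 + math.erf(x / math.sqrt(2)))

# At x = -1: ReLU zeroes the input, Leaky ReLU keeps a tiny slope,
# ELU keeps a smooth bounded negative tail.
```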
Pooling Strategies
Hierarchical TopK Pooling (Default)
Applied after each GNN layer to select the most important nodes (model.py:73-87):

```python
topk_pools = nn.ModuleList([
    TopKPooling(layer_dim, ratio=0.3)
    for _ in range(num_layers)
])
```
Process:
- Compute importance scores for each node
- Keep top 30% (default) highest-scoring nodes
- Update edge_index to reflect pruned graph
- Apply global mean+max pooling on final nodes
Safeguards:
- Minimum ratio clamped to 0.3 (30% retention)
- Prevents empty graphs from over-aggressive pooling
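The selection step can be sketched with plain scores (in TopKPooling the scores come from a learned projection, not fixed values as here):

```python
import math

# Sketch: score-based node selection with the documented ratio clamp.
# Keeps the top ceil(ratio * N) nodes, with ratio clamped to at least 0.3.
def topk_select(scores, ratio):
    ratio = max(ratio, 0.3)             # safeguard against over-aggressive pooling
    k = max(1, math.ceil(ratio * len(scores)))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(order[:k])            # indices of retained nodes

kept = topk_select([0.9, 0.1, 0.5, 0.7, 0.2, 0.8, 0.3, 0.4, 0.6, 0.05],
                   ratio=0.3)
# 10 nodes in, 3 retained → [0, 3, 5]
```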
Forward Pass (model.py:136-155):

```python
for i in range(num_layers):
    x = self.convs[i](x, edge_index)
    x = self.gns[i](x, batch)
    x = self.activation_fn(x)
    x = self.dropout(x)
    x, edge_index, _, batch, _, _ = self.topk_pools[i](x, edge_index, batch=batch)
x_mean = global_mean_pool(x, batch)
x_max = global_max_pool(x, batch)
return torch.cat([x_mean, x_max], dim=1)  # [B, 512]
```
Enable: --use_topk_pooling (default: True)

Configure: --topk_ratio 0.3

Traditional Global Pooling
No node selection - pools all nodes after final layer (model.py:157-171):

```python
for i in range(num_layers):
    x = self.convs[i](x, edge_index)
    x = self.gns[i](x, batch)
    x = self.activation_fn(x)
    x = self.dropout(x)
x_mean = global_mean_pool(x, batch)  # [B, 256]
x_max = global_max_pool(x, batch)    # [B, 256]
return torch.cat([x_mean, x_max], dim=1)  # [B, 512]
```
Operations:
- Mean pooling: Average of all node features
- Max pooling: Maximum activation per feature
- Concatenation: Combines both for richer representation
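A toy illustration of the mean + max readout on a 3-node graph (the real feature width is 256 per branch, giving the 512D embedding):

```python
# Sketch: mean + max readout over one graph's node features.
# Toy graph: 3 nodes, 2 features each.
nodes = [[1.0, 4.0],
         [3.0, 0.0],
         [2.0, 2.0]]

mean = [sum(col) / len(nodes) for col in zip(*nodes)]  # average per feature
mx = [max(col) for col in zip(*nodes)]                 # max per feature

# concatenation doubles the feature dimension (256 → 512 in the real model)
graph_embedding = mean + mx
```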
Enable: Omit --use_topk_pooling flag
Time-Aware GNN (Optional)
When --use_time_features is enabled, the encoder incorporates temporal information:
```python
if self.use_time_features:
    time_projection = nn.Sequential(
        nn.Linear(1, 16),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(16, 32)
    )
    fusion_layer = nn.Sequential(
        nn.Linear(512 + 32, 512),  # Graph (512) + Time (32)
        nn.ReLU(),
        nn.Dropout(dropout)
    )
```
Process (model.py:113-134):
- Standard GNN forward → 512D graph embedding
- Project time-to-predict (1D) → 32D time embedding
- Concatenate: 512D + 32D = 544D
- Fusion layer: 544D → 512D final embedding
Time features are kept small (32D) relative to graph features (512D) to prevent temporal information from dominating spatial brain patterns.
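The dimension bookkeeping can be checked in isolation; a minimal sketch (fusion_widths is a hypothetical helper mirroring the layer sizes above):

```python
# Sketch: track only the tensor widths through the time-aware fusion path.
def fusion_widths(graph_dim=512, time_in=1, time_hidden=16, time_out=32):
    time_path = [time_in, time_hidden, time_out]  # nn.Linear(1, 16) → nn.Linear(16, 32)
    fused_in = graph_dim + time_path[-1]          # concatenation: 512 + 32 = 544
    fused_out = graph_dim                         # nn.Linear(544, 512)
    return time_path, fused_in, fused_out

# → ([1, 16, 32], 544, 512)
```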
Temporal Predictors
Architecture Comparison
Long Short-Term Memory
Implementation: TemporalPredictor.py:24-29

```python
lstm = nn.LSTM(
    input_size=512,      # Graph embedding dimension
    hidden_size=64,      # Hidden state size
    num_layers=1,
    batch_first=True,
    dropout=0.45,        # Applied if num_layers > 1
    bidirectional=True   # Optional
)
```
Components:
- Input gate: Controls information flow into cell
- Forget gate: Decides what to discard from cell state
- Output gate: Controls hidden state output
- Cell state: Long-term memory pathway
Packed Sequences (TemporalPredictor.py:66-71):

```python
fused_packed = nn.utils.rnn.pack_padded_sequence(
    fused, lengths.cpu(), batch_first=True, enforce_sorted=False
)
output_packed, (h_n, c_n) = self.lstm(fused_packed)
```

Final Hidden State (bidirectional):

```python
final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)  # [B, 128]
```
Classifier:

```python
classifier = nn.Sequential(
    nn.Linear(128, 64),  # 64*2 if bidirectional
    nn.ReLU(),
    nn.Dropout(0.45),
    nn.Linear(64, 2)     # Binary classification
)
```
Best For: Long temporal sequences (5+ visits), complex progression patterns

Gated Recurrent Unit
Implementation: GRUPredictor.py:23-28

```python
gru = nn.GRU(
    input_size=512,
    hidden_size=64,
    num_layers=1,
    batch_first=True,
    dropout=0.45,
    bidirectional=False  # Note: GRU is unidirectional
)
```
Components:
- Update gate: Decides how much past to keep
- Reset gate: Determines relevance of past
- No separate cell state (simpler than LSTM)
Forward Pass (GRUPredictor.py:47-56):

```python
fused_packed = nn.utils.rnn.pack_padded_sequence(
    fused, lengths.cpu(), batch_first=True, enforce_sorted=False
)
output_packed, h_n = self.gru(fused_packed)
final_hidden = h_n[-1]  # [B, 64]
logits = self.classifier(final_hidden)
```
Advantages:
- Fewer parameters than LSTM (faster training)
- Often comparable performance
- Less prone to overfitting
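The parameter saving follows directly from the gate counts; a sketch of per-direction, per-layer recurrent parameter counts (assuming the standard PyTorch layout of two bias vectors per gate block):

```python
# Sketch: recurrent-layer parameter count for one layer, one direction.
# Each gate block holds input weights, recurrent weights, and two biases.
def rnn_params(input_size, hidden_size, gates):
    return gates * (input_size * hidden_size
                    + hidden_size * hidden_size
                    + 2 * hidden_size)

lstm_p = rnn_params(512, 64, gates=4)  # LSTM: input/forget/cell/output gates
gru_p = rnn_params(512, 64, gates=3)   # GRU: update/reset/new gates
# at identical sizes the GRU needs exactly 3/4 of the LSTM's parameters
```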
Best For: Medium-length sequences (3-7 visits), faster experimentation

Vanilla Recurrent Neural Network
Implementation: RNNPredictor.py:23-29

```python
rnn = nn.RNN(
    input_size=512,
    hidden_size=64,
    num_layers=1,
    batch_first=True,
    dropout=0.45,
    bidirectional=False,
    nonlinearity='tanh'  # or 'relu'
)
```
Components:
- Single hidden state (no gates)
- Simple recurrence: h_t = tanh(W_ih x_t + W_hh h_(t-1))
Forward Pass (RNNPredictor.py:48-60):

```python
fused_packed = nn.utils.rnn.pack_padded_sequence(
    fused, lengths.cpu(), batch_first=True, enforce_sorted=False
)
output_packed, h_n = self.rnn(fused_packed)
final_hidden = h_n[-1]  # [B, 64]
logits = self.classifier(final_hidden)
```
Limitations:
- Vanishing/exploding gradients on long sequences
- No gating mechanism
Best For: Short sequences (2-3 visits), baseline comparisons
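A minimal sketch of the vanishing-gradient effect behind that limitation (the per-step factor 0.5 is illustrative, standing in for the norm of the tanh-saturated recurrence Jacobian):

```python
# Sketch: the gradient through T recurrence steps is scaled by a product
# of per-step factors; factors below 1 make it decay geometrically.
def gradient_scale(per_step_factor, num_steps):
    scale = 1.0
    for _ in range(num_steps):
        scale *= per_step_factor
    return scale

short = gradient_scale(0.5, 3)   # 3 visits → 0.125, still trainable
long = gradient_scale(0.5, 20)   # 20 visits → ~1e-6, effectively vanished
```

Gated architectures (LSTM/GRU) mitigate this by letting the cell/hidden state pass through additively rather than through a repeated squashing nonlinearity.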
Complete Pipeline
Training Flow
Batched Encoding (TemporalDataLoader.py:193-212)
```python
# Collect all graphs from batch_size subjects
all_data = []
for subject in batch_subjects:
    for visit_idx in subject_visit_indices[subject]:
        all_data.append(dataset[visit_idx])

# Single forward pass for all graphs
big_batch = Batch.from_data_list(all_data).to(device)
encoder.eval()
with torch.no_grad():
    all_embeddings = encoder(
        big_batch.x,           # [N_total_nodes, 100]
        big_batch.edge_index,  # [2, E_total_edges]
        big_batch.batch,       # [N_total_nodes]
        time_features          # [N_graphs, 1] or None
    )  # → [N_graphs, 512]

# Reshape back into per-subject sequences
batch_sequences = []  # List of [T_i, 512] tensors
```
The encoder processes ALL graphs from multiple subjects in a single batch. With batch_size=16 and average 5 visits per subject, this means ~80 graphs per forward pass. Ensure sufficient GPU memory.
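The reshape back into per-subject sequences amounts to splitting the flat embedding list by visit counts; a minimal sketch with integer stand-ins for the 512D embeddings (split_by_subject is a hypothetical name):

```python
# Sketch: the encoder returns one embedding per graph, in subject order;
# the per-subject visit counts recover the [T_i, 512] sequences.
def split_by_subject(embeddings, visits_per_subject):
    sequences, start = [], 0
    for t in visits_per_subject:
        sequences.append(embeddings[start:start + t])
        start += t
    return sequences

# e.g. 3 subjects with 2, 3, and 1 visits → 6 embeddings total
seqs = split_by_subject(list(range(6)), [2, 3, 1])
# → [[0, 1], [2, 3, 4], [5]]
```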
Sequence Classification (TemporalPredictor.py:66-85)
```python
# Packed sequence processing
fused_packed = pack_padded_sequence(
    graph_seq,  # [B, T_max, 512]
    lengths,    # [B] actual sequence lengths
    batch_first=True,
    enforce_sorted=False
)

# LSTM processes only non-padded elements
output_packed, (h_n, c_n) = lstm(fused_packed)

# Extract final hidden state
if bidirectional:
    final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)  # [B, 128]
else:
    final_hidden = h_n[-1]  # [B, 64]

# Classification head
logits = classifier(final_hidden)  # [B, 2]
```
Model Sizes
Parameter Counts
Default Configuration:
- GNN Encoder (GraphSAGE, 2 layers): ~330K parameters
- LSTM (hidden=64, bidirectional): ~50K parameters
- Classifier: ~8K parameters
- Total: ~388K parameters
High-Capacity Configuration:
```shell
--gnn_hidden_dim 512 --gnn_num_layers 4 --lstm_hidden_dim 128 --lstm_num_layers 2
```
- GNN: ~1.3M parameters
- LSTM: ~200K parameters
- Classifier: ~32K parameters
- Total: ~1.5M parameters
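The classifier figures can be reproduced from the Linear layer shapes; a minimal sketch for the default bidirectional head (128 → 64 → 2):

```python
# Sketch: nn.Linear(in, out) holds in*out weights + out biases.
def linear_params(i, o):
    return i * o + o

# Default classification head: Linear(128, 64) → Linear(64, 2)
head_params = linear_params(128, 64) + linear_params(64, 2)
# → 8,386 parameters, i.e. the "~8K" classifier figure for the default setup
```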
Example Configurations
GraphSAGE-LSTM (Default)
```shell
python main.py \
    --layer_type GraphSAGE \
    --gnn_hidden_dim 256 \
    --gnn_num_layers 2 \
    --gnn_activation elu \
    --use_topk_pooling \
    --topk_ratio 0.3 \
    --model_type LSTM \
    --lstm_hidden_dim 64 \
    --lstm_num_layers 1 \
    --lstm_bidirectional
```
GCN-GRU (Lightweight)
```shell
python main.py \
    --layer_type GCN \
    --gnn_hidden_dim 128 \
    --gnn_num_layers 2 \
    --model_type GRU \
    --lstm_hidden_dim 32 \
    --batch_size 32
```
GAT-LSTM (High Capacity)
```shell
python main.py \
    --layer_type GAT \
    --gnn_hidden_dim 512 \
    --gnn_num_layers 3 \
    --gnn_activation gelu \
    --model_type LSTM \
    --lstm_hidden_dim 128 \
    --lstm_num_layers 2 \
    --batch_size 8
```
When using GAT with high hidden dimensions, reduce batch size to avoid out-of-memory errors. GAT requires more memory due to attention computation.