Overview
Graph Neural Networks (GNNs) are the spatial encoder in the STGNN architecture, transforming brain connectivity graphs into fixed-dimensional embeddings that capture connectivity patterns.
Supported Layer Types
The system supports three state-of-the-art GNN architectures, selectable via --layer_type:
GraphSAGE (default): Graph Sample and Aggregate - efficient neighbor sampling and aggregation
GCN: Graph Convolutional Network - spectral-based convolution operation
GAT: Graph Attention Network - learned attention weights for neighbors
GraphSAGE (Default)
Architecture
```python
from torch_geometric.nn import SAGEConv

if layer_type == "GraphSAGE":
    conv = SAGEConv(in_dim, out_dim)
```
Code reference: model.py:60-61
How It Works
Neighborhood Sampling
For each node, sample a fixed number of neighbors (or use all neighbors)
Aggregation
Aggregate neighbor features using mean pooling: $h_{\mathcal{N}(v)}^{(l)} = \text{MEAN}(\{h_u^{(l-1)} : u \in \mathcal{N}(v)\})$
Concatenation & Transformation
Combine the node's own features with the aggregated neighborhood: $h_v^{(l)} = \sigma(W \cdot [h_v^{(l-1)} \| h_{\mathcal{N}(v)}^{(l)}])$
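The three steps above can be sketched in plain Python on a toy 3-node graph. The weight matrix `W` below is hand-picked for illustration (not learned), and the nonlinearity is omitted for clarity; this mirrors the math, not the `SAGEConv` implementation:

```python
# Toy GraphSAGE step: mean-aggregate neighbors, concatenate with the
# node's own features, then apply a linear map.

def mean_agg(neigh_feats):
    """MEAN({h_u : u in N(v)}) over a list of neighbor feature vectors."""
    n = len(neigh_feats)
    return [sum(f[i] for f in neigh_feats) / n for i in range(len(neigh_feats[0]))]

def sage_step(h_v, neigh_feats, W):
    """h_v' = W . [h_v || mean(neighbors)]  (nonlinearity omitted)."""
    z = h_v + mean_agg(neigh_feats)  # list concatenation = feature concat
    return [sum(w * x for w, x in zip(row, z)) for row in W]

h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
W = [[0.5, 0.5, 0.5, 0.5]]  # maps the 4-dim concat to a 1-dim output
print(sage_step(h[0], [h[1], h[2]], W))  # [1.25]
```

Because the aggregation is a symmetric mean, the output is invariant to the order in which neighbors are listed.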
Advantages
Mean aggregation is fast and memory-efficient, suitable for large graphs with many nodes (e.g., 100-400 brain regions).
Can handle varying graph sizes and structures without modification. Works well with different brain parcellation schemes.
Learns a function to generate embeddings rather than embedding lookup, allowing generalization to unseen graph structures.
When to Use
Best for: Large brain graphs (200+ regions), consistent performance, default choice for most applications
Graph Convolutional Network (GCN)
Architecture
```python
from torch_geometric.nn import GCNConv

if layer_type == "GCN":
    conv = GCNConv(in_dim, out_dim)
```
Code reference: model.py:56-57
How It Works
Normalized Aggregation
Aggregate with symmetric normalization: $h_v^{(l)} = \sigma\left(\sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{d_u d_v}} W h_u^{(l-1)}\right)$ where $d_u$ and $d_v$ are node degrees.
Spectral Foundation
Based on spectral graph theory - approximates graph convolution in the frequency domain.
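As a concrete illustration, the normalized aggregation can be computed by hand on a tiny path graph. This is a plain-Python sketch of the formula with the weight matrix taken as the identity; it is not the `GCNConv` implementation:

```python
import math

# Toy GCN aggregation on a path graph 0-1-2:
# sum over N(v) ∪ {v} of h_u / sqrt(d_u * d_v), degrees counting self-loops.

def gcn_agg(v, adj, h):
    deg = {u: len(adj[u]) + 1 for u in adj}  # +1 for the self-loop
    return [
        sum(h[u][i] / math.sqrt(deg[u] * deg[v]) for u in adj[v] + [v])
        for i in range(len(h[v]))
    ]

adj = {0: [1], 1: [0, 2], 2: [1]}  # path graph: 0 - 1 - 2
h = {0: [1.0], 1: [2.0], 2: [3.0]}
print(gcn_agg(1, adj, h))  # (1+3)/sqrt(6) + 2/3 ≈ 2.2997
```

Note how the center node's own contribution is scaled by $1/d_v$ (here $2/3$), so high-degree nodes are automatically down-weighted.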
Advantages
Strong mathematical grounding in spectral graph theory, with provable properties.
Automatic handling of varying node degrees - high-degree nodes don’t dominate.
Clean, elegant formulation with fewer hyperparameters.
When to Use
Best for: Well-understood theoretical properties needed, simpler architecture preferred, smaller graphs (< 200 regions)
Graph Attention Network (GAT)
Architecture
```python
from torch_geometric.nn import GATConv

if layer_type == "GAT":
    conv = GATConv(in_dim, out_dim)
```
Code reference: model.py:58-59
How It Works
Attention Coefficient Computation
Learn the importance of each neighbor via an attention mechanism: $e_{vu} = \text{LeakyReLU}(a^T [W h_v \| W h_u])$
Softmax Normalization
Normalize attention across neighbors: $\alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{k \in \mathcal{N}(v)} \exp(e_{vk})}$
Weighted Aggregation
Aggregate using the learned attention weights: $h_v^{(l)} = \sigma\left(\sum_{u \in \mathcal{N}(v)} \alpha_{vu} W h_u^{(l-1)}\right)$
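The three attention steps can be sketched on a toy example. The attention vector `a` and features below are illustrative, and `W` is taken as the identity for brevity; this mirrors the math, not `GATConv`:

```python
import math

# Toy GAT attention for one node: score each neighbor with
# LeakyReLU(a · [h_v || h_u]), softmax the scores, then take the
# attention-weighted sum of neighbor features.

def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def attention_weights(h_v, neigh, a):
    scores = [leaky_relu(sum(ai * zi for ai, zi in zip(a, h_v + h_u)))
              for h_u in neigh]
    m = max(scores)                            # subtract max: stable softmax
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_agg(alpha, neigh):
    dim = len(neigh[0])
    return [sum(a * h_u[i] for a, h_u in zip(alpha, neigh)) for i in range(dim)]

h_v = [1.0, 0.0]
neigh = [[0.0, 1.0], [1.0, 1.0]]
a = [0.1, 0.2, 0.3, 0.4]
alpha = attention_weights(h_v, neigh, a)
print(alpha)                       # weights sum to 1.0
print(weighted_agg(alpha, neigh))  # attention-weighted neighbor mix
```

The resulting `alpha` values are exactly what one would visualize for interpretability: a per-edge importance that sums to 1 over each node's neighborhood.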
Advantages
Learns which connections matter most for each node, rather than treating all neighbors equally.
Attention weights can be visualized to understand which brain connections the model focuses on.
Can learn multiple attention patterns simultaneously (though the default configuration uses a single head).
When to Use
Best for: Interpretability needed, heterogeneous connectivity importance, when some edges are more critical than others
Computational Cost: Higher memory and computation than GCN/GraphSAGE due to the attention mechanism. May be slower for very large graphs.
Layer Configuration
Dynamic Layer Construction
Layers are built dynamically based on configuration:
```python
for i in range(num_layers):
    if i == 0:
        in_dim = input_dim      # First layer: input dimension
        out_dim = hidden_dim
    elif i == num_layers - 1:
        in_dim = hidden_dim
        out_dim = output_dim    # Last layer: output dimension
    else:
        in_dim = hidden_dim
        out_dim = hidden_dim    # Middle layers: hidden dimension

    # Create layer based on type
    if layer_type == "GCN":
        conv = GCNConv(in_dim, out_dim)
    elif layer_type == "GAT":
        conv = GATConv(in_dim, out_dim)
    elif layer_type == "GraphSAGE":
        conv = SAGEConv(in_dim, out_dim)

    self.convs.append(conv)
    self.gns.append(GraphNorm(out_dim))
```
Code reference: model.py:41-66
Default Configuration
```shell
--layer_type GraphSAGE
--gnn_num_layers 2
--gnn_hidden_dim 256
--gnn_activation elu
```
This creates a 2-layer GraphSAGE network: input_dim → 256 → 256 (before pooling)
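The dimension schedule implied by this configuration can be sketched as a small helper mirroring the construction loop above (the `input_dim` of 400 below is an illustrative parcellation size, not a project default):

```python
# Compute the (in_dim, out_dim) pair for each GNN layer, following the
# first/last/middle logic of the dynamic layer construction.

def layer_dims(num_layers, input_dim, hidden_dim, output_dim):
    dims = []
    for i in range(num_layers):
        in_d = input_dim if i == 0 else hidden_dim
        out_d = output_dim if i == num_layers - 1 else hidden_dim
        dims.append((in_d, out_d))
    return dims

print(layer_dims(2, 400, 256, 256))  # [(400, 256), (256, 256)]
print(layer_dims(3, 400, 256, 128))  # [(400, 256), (256, 256), (256, 128)]
```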
Layer Operations
Each GNN layer includes a standardized processing pipeline:
```python
for i in range(self.num_layers):
    # 1. Graph convolution (message passing)
    x = self.convs[i](x, edge_index)
    # 2. Graph normalization (batch-wise)
    x = self.gns[i](x, batch)
    # 3. Activation function
    x = self.activation_fn(x)
    # 4. Dropout for regularization
    x = self.dropout(x)
    # 5. Optional: TopK pooling (hierarchical coarsening)
    if self.use_topk_pooling:
        x, edge_index, _, batch, perm, score = self.topk_pools[i](
            x, edge_index, batch=batch
        )
```
Code reference: model.py:139-146
Component Details
Graph Normalization
```python
from torch_geometric.nn import GraphNorm

self.gns.append(GraphNorm(out_dim))
x = self.gns[i](x, batch)
```
Purpose: Normalizes node features within each graph (like BatchNorm, but graph-aware)
Benefits:
Stabilizes training
Reduces internal covariate shift
Allows higher learning rates
Code reference: model.py:66
Activation Functions
```python
activation_map = {
    'relu': F.relu,
    'leaky_relu': F.leaky_relu,
    'elu': F.elu,
    'gelu': F.gelu
}
self.activation_fn = activation_map.get(activation, F.relu)
```
Options (via --gnn_activation):
ReLU: Fast, standard choice
LeakyReLU: Prevents dying neurons
ELU: Smooth, faster convergence (default)
GELU: Used in transformers, smooth approximation
Code reference: model.py:28-34
Dropout
```python
self.dropout = nn.Dropout(p=dropout)  # default p=0.5
x = self.dropout(x)
```
Purpose: Regularization to prevent overfitting
Default: 50% dropout rate
When: Applied after each activation function
Code reference: model.py:69, 143
TopK Pooling
Hierarchical graph coarsening that retains the most important nodes:
```python
from torch_geometric.nn import TopKPooling

if use_topk_pooling:
    # Ensure minimum 30% retention
    safe_ratio = max(0.3, min(1.0, topk_ratio))
    for i in range(num_layers):
        pool_dim = output_dim if i == num_layers - 1 else hidden_dim
        self.topk_pools.append(
            TopKPooling(pool_dim, ratio=safe_ratio, min_score=None)
        )
```
Code reference: model.py:72-86
How It Works
Score Computation
Each node receives an importance score: $s_v = \frac{x_v \cdot p}{\|p\|}$ where $p$ is a learned projection vector.
Top-K Selection
Keep only the top $k = \lceil \text{ratio} \times N \rceil$ nodes with the highest scores.
Graph Reconstruction
Update edge index to only include edges between retained nodes.
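The three steps can be sketched in plain Python on a toy graph. The projection vector `p` below is illustrative rather than learned, and edges are kept as a simple pair list rather than an edge-index tensor:

```python
import math

# Toy TopK pooling: score nodes by projection onto p, keep the top
# ceil(ratio * N), and drop edges touching removed nodes.

def topk_pool(h, edges, p, ratio):
    norm = math.sqrt(sum(pi * pi for pi in p))
    scores = {v: sum(x * pi for x, pi in zip(feat, p)) / norm
              for v, feat in h.items()}
    k = math.ceil(ratio * len(h))
    kept = set(sorted(scores, key=scores.get, reverse=True)[:k])
    new_edges = [(u, v) for u, v in edges if u in kept and v in kept]
    return kept, new_edges

h = {0: [1.0], 1: [3.0], 2: [2.0], 3: [0.5]}
edges = [(0, 1), (1, 2), (2, 3)]
kept, new_edges = topk_pool(h, edges, p=[1.0], ratio=0.5)
print(kept, new_edges)  # {1, 2} [(1, 2)]
```

With ratio 0.5 on four nodes, only the two highest-scoring nodes survive, and the edge list shrinks to the edges entirely inside that set.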
Benefits
Progressively reduces graph size through layers, focusing computation on important regions.
Later layers operate on coarser graphs, capturing higher-level patterns.
Learns which brain regions are most diagnostically relevant.
Safety Mechanism
Minimum Retention: The system enforces at least 30% node retention to prevent empty graphs, overriding user-specified topk_ratio if needed (model.py:74).
```python
safe_ratio = max(0.3, min(1.0, topk_ratio))
print(f"TopK pooling initialized with safe ratio: {safe_ratio} (minimum 30%)")
```
Global Pooling
After all GNN layers, graph-level features are extracted via dual pooling:
```python
from torch_geometric.nn import global_mean_pool, global_max_pool

# Aggregate all node features to graph-level representation
x_mean = global_mean_pool(x, batch)  # Average: [B, out_dim]
x_max = global_max_pool(x, batch)    # Maximum: [B, out_dim]

# Concatenate for richer representation
x = torch.cat([x_mean, x_max], dim=1)  # [B, out_dim * 2]
```
Code reference: model.py:149-153, 167-169
Why Dual Pooling?
Mean Pooling
Captures: Average connectivity patterns across all regions
Pros: Smooth, robust to outliers
Represents: Overall brain state
Max Pooling
Captures: Peak activations and extreme values
Pros: Sensitive to salient features
Represents: Most abnormal regions
Output Dimension : With output_dim=256, dual pooling produces 512D embeddings (256 from mean + 256 from max).
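For a single graph, the dual pooling reduces to an element-wise mean and max over node features, concatenated. A plain-Python sketch (toy two-node features; this mirrors the semantics of `global_mean_pool`/`global_max_pool` for one graph, not their batched implementation):

```python
# Dual pooling for one toy graph: mean and max over node features,
# concatenated into a single graph-level embedding of size 2 * dim.

def dual_pool(node_feats):
    dim = len(node_feats[0])
    mean = [sum(f[i] for f in node_feats) / len(node_feats) for i in range(dim)]
    mx = [max(f[i] for f in node_feats) for i in range(dim)]
    return mean + mx

feats = [[1.0, 4.0], [3.0, 0.0]]  # two nodes, 2-dim features each
print(dual_pool(feats))  # [2.0, 2.0, 3.0, 4.0]
```

The first half of the output is the mean (overall pattern), the second half the max (strongest activation per dimension), which is why `output_dim=256` yields a 512-dimensional embedding.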
Layer Type Comparison
| Layer Type | Speed | Memory | Accuracy | Interpretability | Best Use Case |
|---|---|---|---|---|---|
| GraphSAGE | Fast | Low | High | Medium | Default, large graphs |
| GCN | Fast | Low | High | Low | Simpler architecture, theoretical |
| GAT | Slow | High | Highest | High | When interpretability matters |
Depth Considerations
2 Layers (Default)
Receptive Field: 2-hop neighborhood
Pros:
Fast training
Less overfitting risk
Sufficient for most brain parcellation schemes
Cons:
Limited long-range connectivity modeling
3 Layers
Receptive Field: 3-hop neighborhood
Pros:
Captures network-level organization
Better for larger graphs (300+ regions)
Cons:
Slower training
More parameters
4-5 Layers
Receptive Field: 4-5 hop neighborhood
Pros:
Maximum expressiveness
Whole-brain connectivity patterns
Cons:
Over-smoothing risk (nodes become too similar)
Overfitting on small datasets
Slower training
Practical Configuration Guide
Quick Start
```shell
# Default configuration (recommended)
python main.py \
  --layer_type GraphSAGE \
  --gnn_num_layers 2 \
  --gnn_hidden_dim 256 \
  --gnn_activation elu \
  --use_topk_pooling \
  --topk_ratio 0.3
```
For Interpretability
```shell
# Use GAT with attention weights
python main.py \
  --layer_type GAT \
  --gnn_num_layers 2 \
  --gnn_hidden_dim 128  # Reduce dim to manage memory
```
For Large Graphs (400+ regions)
```shell
# Deeper network with aggressive pooling
python main.py \
  --layer_type GraphSAGE \
  --gnn_num_layers 3 \
  --gnn_hidden_dim 256 \
  --use_topk_pooling \
  --topk_ratio 0.2  # Keep only 20% of nodes
```
For Small Datasets
```shell
# Simpler network to prevent overfitting
python main.py \
  --layer_type GCN \
  --gnn_num_layers 2 \
  --gnn_hidden_dim 128 \
  --dropout 0.6  # Higher dropout
```
Debugging and Visualization
Check Layer Dimensions: Add print statements to verify tensor shapes through the network:
```python
for i in range(self.num_layers):
    print(f"Layer {i}: x.shape = {x.shape}, edges = {edge_index.shape[1]}")
    x = self.convs[i](x, edge_index)
```
Common Issues
Empty Graphs After TopK Pooling
Symptom: Runtime error about empty batches
Cause: topk_ratio too aggressive for small graphs
Solution: Increase topk_ratio or disable TopK pooling
```shell
--topk_ratio 0.5  # Keep 50% of nodes
# OR
--use_topk_pooling False
```
Out of Memory
Symptom: CUDA out of memory error
Cause: GAT layers with large graphs, or too large hidden_dim
Solution: Reduce dimensions or use GraphSAGE
```shell
--layer_type GraphSAGE  # More memory efficient
--gnn_hidden_dim 128    # Smaller dimension
--batch_size 8          # Smaller batches
```
Poor Training Performance
Symptom: Loss plateaus early, low accuracy
Cause: Network too shallow or too deep, wrong activation
Solution: Tune architecture hyperparameters
```shell
--gnn_num_layers 3    # Try different depths
--gnn_activation elu  # ELU often works better than ReLU
--lr 0.001            # Adjust learning rate
```
Next Steps
Temporal Modeling How GNN embeddings are processed temporally
Architecture Overview Complete system pipeline
Spatiotemporal Integration Combining spatial and temporal features