
Overview

Graph Neural Networks (GNNs) are the spatial encoder in the STGNN architecture, transforming brain connectivity graphs into fixed-dimensional embeddings that capture connectivity patterns.

Supported Layer Types

The system supports three state-of-the-art GNN architectures, selectable via --layer_type:

GraphSAGE

Default choice. Graph Sample and Aggregate: efficient neighbor sampling and aggregation.

GCN

Graph Convolutional Network - spectral-based convolution operation

GAT

Graph Attention Network - learned attention weights for neighbors

GraphSAGE (Default)

Architecture

from torch_geometric.nn import SAGEConv

if layer_type == "GraphSAGE":
    conv = SAGEConv(in_dim, out_dim)
Code reference: model.py:60-61

How It Works

1. Neighborhood Sampling

For each node, sample a fixed number of neighbors (or use all neighbors).

2. Aggregation

Aggregate neighbor features using mean pooling:

$$h_{\mathcal{N}(v)}^{(l)} = \text{MEAN}\left(\{h_u^{(l-1)} : u \in \mathcal{N}(v)\}\right)$$

3. Concatenation & Transformation

Combine the node's own features with the aggregated neighborhood:

$$h_v^{(l)} = \sigma\left(W \cdot [h_v^{(l-1)} \,\|\, h_{\mathcal{N}(v)}^{(l)}]\right)$$
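The three steps above can be sketched in plain Python. This is a toy, dependency-free illustration of the math only, not the `SAGEConv` implementation; `tanh` stands in for the activation σ, and the dict-based graph layout is an assumption for readability:

```python
import math

def sage_layer(features, neighbors, W, sigma=math.tanh):
    """One GraphSAGE layer: mean-aggregate neighbor features,
    concatenate with the node's own features, apply a linear
    transform W of shape (out_dim, 2 * in_dim), then sigma."""
    out = {}
    for v, h_v in features.items():
        nbrs = neighbors[v]
        # Step 2: mean aggregation over the sampled neighborhood
        agg = [sum(features[u][i] for u in nbrs) / len(nbrs)
               for i in range(len(h_v))]
        # Step 3: concatenate self features with the aggregate,
        # then transform and apply the nonlinearity
        concat = h_v + agg
        out[v] = [sigma(sum(w * x for w, x in zip(row, concat)))
                  for row in W]
    return out

# Toy graph: 3 nodes with 2-dim features
features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbors = {0: [1, 2], 1: [0], 2: [0, 1]}
W = [[1, 0, 1, 0], [0, 1, 0, 1]]  # sums the self and aggregated halves
h1 = sage_layer(features, neighbors, W)
```

Node 1 has a single neighbor (node 0), so its aggregate is exactly node 0's features and its output is `tanh(1.0)` in both dimensions.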

Advantages

Mean aggregation is fast and memory-efficient, suitable for large graphs with many nodes (e.g., 100-400 brain regions).
Can handle varying graph sizes and structures without modification. Works well with different brain parcellation schemes.
Learns a function to generate embeddings rather than embedding lookup, allowing generalization to unseen graph structures.

When to Use

Best for: Large brain graphs (200+ regions), consistent performance, default choice for most applications

Graph Convolutional Network (GCN)

Architecture

from torch_geometric.nn import GCNConv

if layer_type == "GCN":
    conv = GCNConv(in_dim, out_dim)
Code reference: model.py:56-57

How It Works

1. Normalized Aggregation

Aggregate with symmetric normalization:

$$h_v^{(l)} = \sigma\left(\sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{d_u d_v}} W h_u^{(l-1)}\right)$$

where $d_u$ and $d_v$ are node degrees.

2. Spectral Foundation

Based on spectral graph theory: it approximates graph convolution in the frequency domain.
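A minimal plain-Python sketch of the normalized aggregation above (a toy illustration of the formula, not the `GCNConv` implementation; degrees include the self-loop, and `tanh` stands in for σ):

```python
import math

def gcn_layer(features, neighbors, W, sigma=math.tanh):
    """One GCN layer with symmetric normalization 1/sqrt(d_u * d_v)
    over the neighborhood plus a self-loop."""
    # Degree of each node, counting the added self-loop
    deg = {v: len(nbrs) + 1 for v, nbrs in neighbors.items()}
    out = {}
    for v in features:
        in_dim = len(features[v])
        acc = [0.0] * in_dim
        for u in neighbors[v] + [v]:  # N(v) ∪ {v}
            norm = 1.0 / math.sqrt(deg[u] * deg[v])
            for i in range(in_dim):
                acc[i] += norm * features[u][i]
        # Linear transform + nonlinearity
        out[v] = [sigma(sum(w * x for w, x in zip(row, acc)))
                  for row in W]
    return out

features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbors = {0: [1, 2], 1: [0], 2: [0, 1]}
W = [[1, 0], [0, 1]]  # identity transform, to expose the normalization
h1 = gcn_layer(features, neighbors, W)
```

Because the normalization divides by $\sqrt{d_u d_v}$, a high-degree neighbor contributes proportionally less, which is why high-degree nodes don't dominate.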

Advantages

Strong mathematical grounding in spectral graph theory, with provable properties.
Automatic handling of varying node degrees - high-degree nodes don’t dominate.
Clean, elegant formulation with fewer hyperparameters.

When to Use

Best for: Well-understood theoretical properties needed, simpler architecture preferred, smaller graphs (< 200 regions)

Graph Attention Network (GAT)

Architecture

from torch_geometric.nn import GATConv

if layer_type == "GAT":
    conv = GATConv(in_dim, out_dim)
Code reference: model.py:58-59

How It Works

1. Attention Coefficient Computation

Learn the importance of each neighbor via an attention mechanism:

$$e_{vu} = \text{LeakyReLU}\left(a^T [W h_v \,\|\, W h_u]\right)$$

2. Softmax Normalization

Normalize attention across neighbors:

$$\alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{k \in \mathcal{N}(v)} \exp(e_{vk})}$$

3. Weighted Aggregation

Aggregate using the learned attention weights:

$$h_v^{(l)} = \sigma\left(\sum_{u \in \mathcal{N}(v)} \alpha_{vu} W h_u^{(l-1)}\right)$$
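The three attention steps can be sketched in plain Python for a single head (a toy illustration, not the `GATConv` implementation; the LeakyReLU slope of 0.2 and `tanh` as σ are assumptions):

```python
import math

def gat_layer(features, neighbors, W, a, sigma=math.tanh):
    """One single-head GAT layer: score neighbors, softmax over the
    neighborhood, aggregate with the attention weights.
    W: (out_dim, in_dim); a: attention vector of length 2 * out_dim."""
    def matvec(M, x):
        return [sum(m * xi for m, xi in zip(row, x)) for row in M]

    def leaky_relu(x, slope=0.2):
        return x if x > 0 else slope * x

    Wh = {v: matvec(W, h) for v, h in features.items()}
    out = {}
    for v in features:
        nbrs = neighbors[v]
        # Step 1: attention logits e_vu = LeakyReLU(a^T [Wh_v || Wh_u])
        e = [leaky_relu(sum(ai * xi for ai, xi in zip(a, Wh[v] + Wh[u])))
             for u in nbrs]
        # Step 2: softmax over the neighborhood (max-shifted for stability)
        m = max(e)
        exp_e = [math.exp(x - m) for x in e]
        alpha = [x / sum(exp_e) for x in exp_e]
        # Step 3: attention-weighted aggregation
        out[v] = [sigma(sum(al * Wh[u][i] for al, u in zip(alpha, nbrs)))
                  for i in range(len(Wh[v]))]
    return out

features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbors = {0: [1, 2], 1: [0], 2: [0, 1]}
W = [[1, 0], [0, 1]]
a = [0.5, 0.5, 0.5, 0.5]
h1 = gat_layer(features, neighbors, W, a)
```

For a node with a single neighbor, the softmax collapses to weight 1.0, so its output is just σ applied to that neighbor's transformed features; the extra cost relative to GCN/GraphSAGE is the per-edge logit and softmax.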

Advantages

Learns which connections matter most for each node, rather than treating all neighbors equally.
Attention weights can be visualized to understand which brain connections the model focuses on.
Can learn multiple attention patterns simultaneously via multi-head attention (though the default configuration uses a single head).

When to Use

Best for: Interpretability needed, heterogeneous connectivity importance, when some edges are more critical than others
Computational Cost: Higher memory and computation than GCN/GraphSAGE due to attention mechanism. May be slower for very large graphs.

Layer Configuration

Dynamic Layer Construction

Layers are built dynamically based on configuration:
for i in range(num_layers):
    if i == 0:
        in_dim = input_dim      # First layer: input dimension
        out_dim = hidden_dim
    elif i == num_layers - 1:
        in_dim = hidden_dim
        out_dim = output_dim    # Last layer: output dimension
    else:
        in_dim = hidden_dim
        out_dim = hidden_dim    # Middle layers: hidden dimension
    
    # Create layer based on type
    if layer_type == "GCN":
        conv = GCNConv(in_dim, out_dim)
    elif layer_type == "GAT":
        conv = GATConv(in_dim, out_dim)
    elif layer_type == "GraphSAGE":
        conv = SAGEConv(in_dim, out_dim)
    
    self.convs.append(conv)
    self.gns.append(GraphNorm(out_dim))
Code reference: model.py:41-66

Default Configuration

--layer_type GraphSAGE
--gnn_num_layers 2
--gnn_hidden_dim 256
--gnn_activation elu
This creates a 2-layer GraphSAGE network: input_dim → 256 → 256 (before pooling)

Layer Operations

Each GNN layer includes a standardized processing pipeline:
for i in range(self.num_layers):
    # 1. Graph convolution (message passing)
    x = self.convs[i](x, edge_index)
    
    # 2. Graph normalization (batch-wise)
    x = self.gns[i](x, batch)
    
    # 3. Activation function
    x = self.activation_fn(x)
    
    # 4. Dropout for regularization
    x = self.dropout(x)
    
    # 5. Optional: TopK pooling (hierarchical coarsening)
    if self.use_topk_pooling:
        x, edge_index, _, batch, perm, score = self.topk_pools[i](
            x, edge_index, batch=batch
        )
Code reference: model.py:139-146

Component Details

from torch_geometric.nn import GraphNorm

self.gns.append(GraphNorm(out_dim))
x = self.gns[i](x, batch)
Purpose: Normalizes node features within each graph (like BatchNorm, but graph-aware).

Benefits:
  • Stabilizes training
  • Reduces internal covariate shift
  • Allows higher learning rates
Code reference: model.py:66

TopK Pooling

Hierarchical graph coarsening that retains the most important nodes:
from torch_geometric.nn import TopKPooling

if use_topk_pooling:
    # Ensure minimum 30% retention
    safe_ratio = max(0.3, min(1.0, topk_ratio))
    
    for i in range(num_layers):
        pool_dim = output_dim if i == num_layers - 1 else hidden_dim
        self.topk_pools.append(
            TopKPooling(pool_dim, ratio=safe_ratio, min_score=None)
        )
Code reference: model.py:72-86

How It Works

1. Score Computation

Each node receives an importance score:

$$s_v = \frac{x_v \cdot p}{\|p\|}$$

where $p$ is a learned projection vector.

2. Top-K Selection

Keep only the top $k = \lceil \text{ratio} \times N \rceil$ nodes with the highest scores.

3. Graph Reconstruction

Update the edge index to include only edges between retained nodes.
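The three steps above can be sketched in plain Python (a toy illustration, not the `TopKPooling` implementation; in the real layer the projection vector `p` is learned, here it is fixed):

```python
import math

def topk_pool(x, edges, p, ratio):
    """Toy TopK pooling: score nodes by projection onto p, keep the
    top ceil(ratio * N), drop edges touching removed nodes."""
    # Step 1: importance score s_v = (x_v . p) / ||p||
    norm_p = math.sqrt(sum(pi * pi for pi in p))
    scores = {v: sum(xi * pi for xi, pi in zip(feat, p)) / norm_p
              for v, feat in x.items()}
    # Step 2: keep the top-k nodes by score
    k = math.ceil(ratio * len(x))
    keep = set(sorted(scores, key=scores.get, reverse=True)[:k])
    # Step 3: reconstruct the graph over retained nodes only
    new_x = {v: x[v] for v in keep}
    new_edges = [(u, v) for u, v in edges if u in keep and v in keep]
    return new_x, new_edges, scores

x = {0: [1.0, 2.0], 1: [0.5, 0.5], 2: [2.0, 2.0], 3: [0.1, 0.0]}
edges = [(0, 1), (1, 2), (2, 3), (0, 2)]
new_x, new_edges, scores = topk_pool(x, edges, p=[1.0, 1.0], ratio=0.5)
```

With `ratio=0.5` on four nodes, the two highest-scoring nodes (0 and 2) survive, and only the edge between them remains.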

Benefits

Progressively reduces graph size through layers, focusing computation on important regions.
Later layers operate on coarser graphs, capturing higher-level patterns.
Learns which brain regions are most diagnostically relevant.

Safety Mechanism

Minimum Retention: The system enforces at least 30% node retention to prevent empty graphs, overriding user-specified topk_ratio if needed (model.py:74).
safe_ratio = max(0.3, min(1.0, topk_ratio))
print(f"TopK pooling initialized with safe ratio: {safe_ratio} (minimum 30%)")

Global Pooling

After all GNN layers, graph-level features are extracted via dual pooling:
from torch_geometric.nn import global_mean_pool, global_max_pool

# Aggregate all node features to graph-level representation
x_mean = global_mean_pool(x, batch)  # Average: [B, out_dim]
x_max = global_max_pool(x, batch)    # Maximum: [B, out_dim]

# Concatenate for richer representation
x = torch.cat([x_mean, x_max], dim=1)  # [B, out_dim * 2]
Code reference: model.py:149-153, 167-169

Why Dual Pooling?

Mean Pooling

Captures: Average connectivity patterns across all regions
Pros: Smooth, robust to outliers
Represents: Overall brain state

Max Pooling

Captures: Peak activations and extreme values
Pros: Sensitive to salient features
Represents: Most abnormal regions
Output Dimension: With output_dim=256, dual pooling produces 512D embeddings (256 from mean + 256 from max).
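A plain-Python sketch of the dual readout (toy features; the real model uses PyG's `global_mean_pool`/`global_max_pool` over a batched node tensor):

```python
def dual_pool(node_features):
    """Graph-level readout: elementwise mean and max over all node
    feature vectors, concatenated into one embedding of 2x the
    per-node feature dimension."""
    dim = len(node_features[0])
    mean = [sum(f[i] for f in node_features) / len(node_features)
            for i in range(dim)]
    mx = [max(f[i] for f in node_features) for i in range(dim)]
    return mean + mx  # length 2 * dim

nodes = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]
emb = dual_pool(nodes)  # [mean_0, mean_1, max_0, max_1]
```

The mean half smooths over all regions while the max half preserves the most extreme value in each feature dimension, which is why concatenating them gives a richer graph embedding than either alone.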

Comparative Performance

Layer Type Comparison

| Layer Type | Speed | Memory | Accuracy | Interpretability | Best Use Case |
|---|---|---|---|---|---|
| GraphSAGE | Fast | Low | High | Medium | Default, large graphs |
| GCN | Fast | Low | High | Low | Simpler architecture, theoretical grounding |
| GAT | Slow | High | Highest | High | When interpretability matters |

Depth Considerations

Receptive Field: 2-hop neighborhood (with the default 2 layers)

Pros:
  • Fast training
  • Less overfitting risk
  • Sufficient for most brain parcellation schemes

Cons:
  • Limited long-range connectivity modeling

Practical Configuration Guide

Quick Start

# Default configuration (recommended)
python main.py \
  --layer_type GraphSAGE \
  --gnn_num_layers 2 \
  --gnn_hidden_dim 256 \
  --gnn_activation elu \
  --use_topk_pooling \
  --topk_ratio 0.3

For Interpretability

# Use GAT with attention weights
python main.py \
  --layer_type GAT \
  --gnn_num_layers 2 \
  --gnn_hidden_dim 128  # Reduce dim to manage memory

For Large Graphs (400+ regions)

# Deeper network with aggressive pooling
python main.py \
  --layer_type GraphSAGE \
  --gnn_num_layers 3 \
  --gnn_hidden_dim 256 \
  --use_topk_pooling \
  --topk_ratio 0.2  # Keep only 20% of nodes

For Small Datasets

# Simpler network to prevent overfitting
python main.py \
  --layer_type GCN \
  --gnn_num_layers 2 \
  --gnn_hidden_dim 128 \
  --dropout 0.6  # Higher dropout

Debugging and Visualization

Check Layer Dimensions: Add print statements to verify tensor shapes through the network:
for i in range(self.num_layers):
    print(f"Layer {i}: x.shape = {x.shape}, edges = {edge_index.shape[1]}")
    x = self.convs[i](x, edge_index)

Common Issues

Symptom: Runtime error about empty batches
Cause: topk_ratio too aggressive for small graphs
Solution: Increase topk_ratio or disable TopK pooling
--topk_ratio 0.5  # Keep 50% of nodes
# OR
--use_topk_pooling False
Symptom: CUDA out-of-memory error
Cause: GAT layers on large graphs, or hidden_dim too large
Solution: Reduce dimensions or switch to GraphSAGE
--layer_type GraphSAGE  # More memory efficient
--gnn_hidden_dim 128    # Smaller dimension
--batch_size 8          # Smaller batches
Symptom: Loss plateaus early, low accuracy
Cause: Network too shallow or too deep, or a poor activation choice
Solution: Tune the architecture hyperparameters
--gnn_num_layers 3      # Try different depths
--gnn_activation elu    # ELU often works better than ReLU
--lr 0.001              # Adjust learning rate

Next Steps

Temporal Modeling

How GNN embeddings are processed temporally

Architecture Overview

Complete system pipeline

Spatiotemporal Integration

Combining spatial and temporal features
