Architecture Overview

The STGNN model combines two neural network components:
  1. Graph Encoder (GraphNeuralNetwork): Extracts spatial features from brain connectivity graphs
  2. Temporal Predictor (LSTM/GRU/RNN): Models temporal progression patterns across visits
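Taken together, the pipeline encodes each visit's graph independently and then models the resulting embedding sequence. A shape-only sketch with stand-in stubs (the function names are illustrative, not the repository's API):

```python
# Shape-only sketch of the two-stage STGNN pipeline.
# encode_graph / predict_sequence are placeholders for the real modules.

def encode_graph(graph):
    """Stands in for the GNN encoder: one visit's graph -> 512-D embedding."""
    return [0.0] * 512

def predict_sequence(embeddings):
    """Stands in for the temporal predictor + classifier: sequence -> 2 logits."""
    return [0.0, 0.0]

visits = ["graph_t0", "graph_t1", "graph_t2"]   # one subject, three visits
sequence = [encode_graph(g) for g in visits]    # shape [T=3, 512]
logits = predict_sequence(sequence)             # shape [2]
```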

Graph Neural Network Encoder

Dynamic Layer Architecture

The GNN encoder dynamically constructs layers based on the num_layers parameter (model.py:40-66):
for i in range(num_layers):
    if i == 0:
        in_dim, out_dim = input_dim, hidden_dim      # 100 → 256
    elif i == num_layers - 1:
        in_dim, out_dim = hidden_dim, output_dim     # 256 → 256
    else:
        in_dim, out_dim = hidden_dim, hidden_dim     # 256 → 256
Default Configuration (2 layers):
  • Layer 1: 100 → 256 (input features to hidden)
  • Layer 2: 256 → 256 (hidden to output)
  • Final pooling: 256 × 2 = 512D (mean + max concatenation)
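The dimension schedule can be traced for any depth with a small helper that mirrors the loop above (the helper name is illustrative):

```python
def layer_dims(num_layers, input_dim=100, hidden_dim=256, output_dim=256):
    """Reproduce the (in_dim, out_dim) schedule from the construction loop."""
    dims = []
    for i in range(num_layers):
        if i == 0:
            dims.append((input_dim, hidden_dim))
        elif i == num_layers - 1:
            dims.append((hidden_dim, output_dim))
        else:
            dims.append((hidden_dim, hidden_dim))
    return dims

print(layer_dims(2))   # [(100, 256), (256, 256)]  -- the default configuration
print(layer_dims(3))   # [(100, 256), (256, 256), (256, 256)]
```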

Layer Types

GraphSAGE Architecture

Aggregation Strategy: Sampling and aggregating neighborhood information
conv = SAGEConv(in_dim, out_dim)
Properties:
  • Inductive learning (generalizes to unseen nodes)
  • Mean aggregation of neighbor features
  • Efficient for large graphs
  • Best for variable graph structures
Use Case: Default choice for brain connectivity graphs with varying node importance
Reference: model.py:60-61
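The mean-aggregation idea behind SAGEConv can be shown with a toy one-dimensional layer. This is a didactic sketch, not the actual SAGEConv implementation, which uses learned weight matrices and supports other aggregators:

```python
def sage_mean_layer(x, edges, w_self, w_neigh):
    """Toy 1-D GraphSAGE-style layer: combine each node's own feature with
    the mean of its in-neighbours' features (mean aggregation)."""
    neighbours = {v: [] for v in range(len(x))}
    for src, dst in edges:
        neighbours[dst].append(x[src])
    out = []
    for v, xv in enumerate(x):
        mean_n = sum(neighbours[v]) / len(neighbours[v]) if neighbours[v] else 0.0
        out.append(w_self * xv + w_neigh * mean_n)
    return out

x = [1.0, 2.0, 3.0]
edges = [(0, 1), (2, 1)]          # node 1 aggregates nodes 0 and 2
print(sage_mean_layer(x, edges, w_self=1.0, w_neigh=1.0))
# → [1.0, 4.0, 3.0]  (node 1: 2.0 + mean(1.0, 3.0) = 4.0)
```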

Activation Functions

Configurable via --gnn_activation (model.py:27-34):
activation_fn = F.elu
Formula: f(x) = x if x > 0 else α(e^x - 1)
Properties:
  • Smooth for negative values
  • Faster convergence than ReLU
  • Mean activation closer to zero
Default α: 1.0

Pooling Strategies

Hierarchical TopK Pooling

Applied after each GNN layer to select the most important nodes (model.py:73-87):
topk_pools = nn.ModuleList([
    TopKPooling(layer_dim, ratio=0.3) 
    for _ in range(num_layers)
])
Process:
  1. Compute importance scores for each node
  2. Keep top 30% (default) highest-scoring nodes
  3. Update edge_index to reflect pruned graph
  4. Apply global mean+max pooling on final nodes
Safeguards:
  • Minimum ratio clamped to 0.3 (30% retention)
  • Prevents empty graphs from over-aggressive pooling
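The selection step can be sketched in isolation. Note that TopKPooling actually learns a projection vector to produce the node scores and gates the surviving features by them; in this sketch the scores are simply given:

```python
import math

def topk_select(scores, ratio=0.3):
    """Return the (sorted) indices of the top ceil(ratio * N) scoring nodes."""
    k = max(1, math.ceil(ratio * len(scores)))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])

scores = [0.9, 0.1, 0.8, 0.3, 0.5, 0.2, 0.7, 0.4, 0.6, 0.05]
print(topk_select(scores, ratio=0.3))   # → [0, 2, 6]: 3 of 10 nodes survive
```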
Forward Pass (model.py:136-155):
for i in range(num_layers):
    x = self.convs[i](x, edge_index)
    x = self.gns[i](x, batch)
    x = self.activation_fn(x)
    x = self.dropout(x)
    x, edge_index, _, batch, _, _ = self.topk_pools[i](x, edge_index, batch=batch)

x_mean = global_mean_pool(x, batch)
x_max = global_max_pool(x, batch)
return torch.cat([x_mean, x_max], dim=1)  # [B, 512]
Enable: --use_topk_pooling (default: True)
Configure: --topk_ratio 0.3

Time-Aware GNN (Optional)

When --use_time_features is enabled, the encoder incorporates temporal information:
if self.use_time_features:
    time_projection = nn.Sequential(
        nn.Linear(1, 16),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(16, 32)
    )
    
    fusion_layer = nn.Sequential(
        nn.Linear(512 + 32, 512),  # Graph (512) + Time (32)
        nn.ReLU(),
        nn.Dropout(dropout)
    )
Process (model.py:113-134):
  1. Standard GNN forward → 512D graph embedding
  2. Project time-to-predict (1D) → 32D time embedding
  3. Concatenate: 512D + 32D = 544D
  4. Fusion layer: 544D → 512D final embedding
Time features are kept small (32D) relative to graph features (512D) to prevent temporal information from dominating spatial brain patterns.
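The dimension bookkeeping is easy to verify with plain lists (in the real model, the concatenation is followed by the learned 544 → 512 fusion layer):

```python
graph_emb = [0.0] * 512   # output of the GNN encoder
time_emb = [0.0] * 32     # output of the time projection MLP (1 -> 16 -> 32)

fused = graph_emb + time_emb   # list concatenation: 512 + 32 = 544 dims
print(len(fused))              # 544; the fusion layer maps this back to 512
```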

Temporal Predictors

Architecture Comparison

Long Short-Term Memory

Implementation: TemporalPredictor.py:24-29
lstm = nn.LSTM(
    input_size=512,              # Graph embedding dimension
    hidden_size=64,              # Hidden state size
    num_layers=1,
    batch_first=True,
    dropout=0.45,                # Applied if num_layers > 1
    bidirectional=True           # Optional
)
Components:
  • Input gate: Controls information flow into cell
  • Forget gate: Decides what to discard from cell state
  • Output gate: Controls hidden state output
  • Cell state: Long-term memory pathway
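These four components implement the standard LSTM update; in textbook notation (generic formulation, not repo-specific symbols):

```latex
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) && \text{input gate} \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) && \text{forget gate} \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) && \text{output gate} \\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) && \text{candidate cell state} \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t && \text{cell state (long-term memory)} \\
h_t &= o_t \odot \tanh(c_t) && \text{hidden state output}
\end{aligned}
```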
Packed Sequences (TemporalPredictor.py:66-71):
fused_packed = nn.utils.rnn.pack_padded_sequence(
    fused, lengths.cpu(), batch_first=True, enforce_sorted=False
)
output_packed, (h_n, c_n) = self.lstm(fused_packed)
Final Hidden State (bidirectional):
final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)  # [B, 128]
Classifier:
classifier = nn.Sequential(
    nn.Linear(128, 64),  # 64*2 if bidirectional
    nn.ReLU(),
    nn.Dropout(0.45),
    nn.Linear(64, 2)     # Binary classification
)
Best For: Long temporal sequences (5+ visits), complex progression patterns

Complete Pipeline

Training Flow

Batched Encoding (TemporalDataLoader.py:193-212)

# Collect all graphs from batch_size subjects
all_data = []
for subject in batch_subjects:
    for visit_idx in subject_visit_indices[subject]:
        all_data.append(dataset[visit_idx])

# Single forward pass for all graphs
big_batch = Batch.from_data_list(all_data).to(device)
encoder.eval()
with torch.no_grad():
    all_embeddings = encoder(
        big_batch.x,           # [N_total_nodes, 100]
        big_batch.edge_index,  # [2, E_total_edges]
        big_batch.batch,       # [N_total_nodes]
        time_features          # [N_graphs, 1] or None
    )  # → [N_graphs, 512]

# Reshape back into per-subject sequences (embeddings come out in the
# same subject/visit order the graphs were collected in)
batch_sequences = []  # List of [T_i, 512] tensors
offset = 0
for subject in batch_subjects:
    n_visits = len(subject_visit_indices[subject])
    batch_sequences.append(all_embeddings[offset:offset + n_visits])
    offset += n_visits
The encoder processes ALL graphs from multiple subjects in a single batch. With batch_size=16 and average 5 visits per subject, this means ~80 graphs per forward pass. Ensure sufficient GPU memory.

Sequence Classification (TemporalPredictor.py:66-85)

# Packed sequence processing
fused_packed = pack_padded_sequence(
    graph_seq,      # [B, T_max, 512]
    lengths,        # [B] actual sequence lengths
    batch_first=True,
    enforce_sorted=False
)

# LSTM processes only non-padded elements
output_packed, (h_n, c_n) = lstm(fused_packed)

# Extract final hidden state
if bidirectional:
    final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)  # [B, 128]
else:
    final_hidden = h_n[-1]  # [B, 64]

# Classification head
logits = classifier(final_hidden)  # [B, 2]
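What packing buys can be seen with a plain-Python stand-in: without packed sequences you would have to select each sequence's final real timestep by hand, and the LSTM would waste computation on padding (illustrative sketch, not the library call):

```python
def last_valid_states(outputs, lengths):
    """For padded sequences, pick each sequence's final *real* timestep --
    the effect pack_padded_sequence achieves without computing on padding."""
    return [seq[length - 1] for seq, length in zip(outputs, lengths)]

# two subjects, padded to T_max = 3; "h{t}" stands in for a hidden state
outputs = [["h0", "h1", "h2"], ["h0", "h1", "pad"]]
lengths = [3, 2]
print(last_valid_states(outputs, lengths))   # → ['h2', 'h1']
```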

Model Sizes

Parameter Counts

Default Configuration:
  • GNN Encoder (GraphSAGE, 2 layers): ~330K parameters
  • LSTM (input=512, hidden=64, bidirectional): ~296K parameters
  • Classifier: ~8K parameters
  • Total: ~634K parameters
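Counts like these are easiest to sanity-check with sum(p.numel() for p in model.parameters()); for a single-layer nn.LSTM there is also a closed form (four gate blocks of input-to-hidden weights, hidden-to-hidden weights, and two bias vectors per direction):

```python
def lstm_param_count(input_size, hidden_size, bidirectional=False):
    """Closed-form parameter count for a single-layer nn.LSTM:
    per direction, weight_ih is [4H x input], weight_hh is [4H x H],
    and bias_ih / bias_hh are each [4H]."""
    per_direction = 4 * hidden_size * (input_size + hidden_size + 2)
    return per_direction * (2 if bidirectional else 1)

print(lstm_param_count(512, 64, bidirectional=True))   # → 295936
```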
High-Capacity Configuration:
--gnn_hidden_dim 512 --gnn_num_layers 4 --lstm_hidden_dim 128 --lstm_num_layers 2
  • GNN: ~1.7M parameters
  • LSTM (2 layers, hidden=128): ~461K parameters
  • Classifier: ~8K parameters
  • Total: ~2.1M parameters

Example Configurations

GraphSAGE-LSTM (Default)

python main.py \
  --layer_type GraphSAGE \
  --gnn_hidden_dim 256 \
  --gnn_num_layers 2 \
  --gnn_activation elu \
  --use_topk_pooling \
  --topk_ratio 0.3 \
  --model_type LSTM \
  --lstm_hidden_dim 64 \
  --lstm_num_layers 1 \
  --lstm_bidirectional

GCN-GRU (Lightweight)

python main.py \
  --layer_type GCN \
  --gnn_hidden_dim 128 \
  --gnn_num_layers 2 \
  --model_type GRU \
  --lstm_hidden_dim 32 \
  --batch_size 32

GAT-LSTM (High Capacity)

python main.py \
  --layer_type GAT \
  --gnn_hidden_dim 512 \
  --gnn_num_layers 3 \
  --gnn_activation gelu \
  --model_type LSTM \
  --lstm_hidden_dim 128 \
  --lstm_num_layers 2 \
  --batch_size 8
When using GAT with high hidden dimensions, reduce batch size to avoid out-of-memory errors. GAT requires more memory due to attention computation.
