
Overview

GraphNeuralNetwork is a flexible GNN encoder that processes individual brain functional connectivity (FC) graphs. It supports multiple layer types (GCN, GAT, GraphSAGE), optional TopK pooling, and time-aware features.

Class Signature

class GraphNeuralNetwork(nn.Module):
    def __init__(
        self,
        input_dim=100,
        hidden_dim=128,
        output_dim=256,
        dropout=0.5,
        use_topk_pooling=True,
        topk_ratio=0.5,
        layer_type="GCN",
        num_layers=3,
        activation='relu',
        use_time_features=False
    )

Parameters

  • input_dim (int, default: 100): Number of input node features (e.g., brain regions)
  • hidden_dim (int, default: 128): Hidden dimension for intermediate GNN layers
  • output_dim (int, default: 256): Output dimension of the final GNN layer (before pooling)
  • dropout (float, default: 0.5): Dropout probability for regularization
  • use_topk_pooling (bool, default: True): Whether to use hierarchical TopK pooling instead of global pooling
  • topk_ratio (float, default: 0.5): Ratio of nodes to keep in TopK pooling (a minimum of 30% is enforced)
  • layer_type (str, default: "GCN"): Type of GNN layer: "GCN", "GAT", or "GraphSAGE"
  • num_layers (int, default: 3): Number of GNN convolutional layers
  • activation (str, default: "relu"): Activation function: "relu", "leaky_relu", "elu", or "gelu"
  • use_time_features (bool, default: False): Whether to incorporate temporal-gap features for time-aware prediction

Forward Method

def forward(self, x, edge_index, batch, time_to_predict=None)

Parameters

  • x (torch.Tensor): Node features of shape [num_nodes, input_dim]
  • edge_index (torch.Tensor): Edge indices in COO format, shape [2, num_edges]
  • batch (torch.Tensor): Batch assignment vector mapping each node to its graph
  • time_to_predict (torch.Tensor, default: None): Optional time gaps for prediction, shape [batch_size] or [batch_size, 1]; required if use_time_features=True

Returns

  • embedding (torch.Tensor): Graph embedding of shape [batch_size, output_dim * 2] (concatenated mean and max pooling)

Example Usage

From main.py:177:
encoder = GraphNeuralNetwork(
    input_dim=100,
    hidden_dim=opt.gnn_hidden_dim,  # 256
    output_dim=256,
    dropout=0.2,
    use_topk_pooling=opt.use_topk_pooling,
    topk_ratio=opt.topk_ratio,  # 0.3
    layer_type=opt.layer_type,  # "GraphSAGE"
    num_layers=opt.gnn_num_layers,  # 2
    activation=opt.gnn_activation,  # "elu"
    use_time_features=opt.use_time_features
).to(device)
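A forward pass takes PyG-style tensors. The snippet below builds a dummy two-graph batch in plain torch to illustrate the expected shapes; the node and edge counts are arbitrary illustration values, and the actual call on the encoder is shown in a comment.

```python
import torch

# Dummy batch of two graphs with 100 nodes each (matching input_dim=100).
x = torch.randn(200, 100)                     # [num_nodes, input_dim]
edge_index = torch.randint(0, 200, (2, 500))  # [2, num_edges], COO format
batch = torch.cat([torch.zeros(100, dtype=torch.long),
                   torch.ones(100, dtype=torch.long)])  # node -> graph id

# With an encoder constructed as above:
#   embedding = encoder(x, edge_index, batch)  # shape [2, 512] per the Returns section
```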

Architecture Details

Layer Configuration

  • First layer: input_dim → hidden_dim
  • Middle layers: hidden_dim → hidden_dim
  • Final layer: hidden_dim → output_dim
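The wiring above can be expressed as a list of (in, out) dimension pairs per layer. This helper is a sketch of that rule, not the class's actual constructor code:

```python
def layer_dims(input_dim, hidden_dim, output_dim, num_layers):
    """Return the (in, out) dimensions for each GNN layer."""
    if num_layers == 1:
        return [(input_dim, output_dim)]
    dims = [(input_dim, hidden_dim)]                 # first layer
    dims += [(hidden_dim, hidden_dim)] * (num_layers - 2)  # middle layers
    dims.append((hidden_dim, output_dim))            # final layer
    return dims

print(layer_dims(100, 128, 256, 3))
# [(100, 128), (128, 128), (128, 256)]
```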

Pooling Strategy

With use_topk_pooling=True:
  • TopK pooling applied after each layer
  • Safe ratio enforced (minimum 30% of nodes retained)
  • Final output: concatenated global mean + max pooling
With use_topk_pooling=False:
  • Traditional global mean + max pooling on final layer only
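The two ingredients above, the 30% safety clamp on the TopK ratio and the concatenated mean+max readout, can be sketched in plain torch. The function names here are illustrative, not the class's actual helpers:

```python
import torch

def safe_topk_ratio(ratio, lo=0.3, hi=1.0):
    # Clamp the TopK keep-ratio, mirroring the minimum-30% rule
    return max(lo, min(hi, ratio))

def global_mean_max_pool(x, batch, num_graphs):
    # x: [num_nodes, dim], batch: [num_nodes] graph assignment
    # (uses Tensor.index_reduce_, available in torch >= 1.13)
    dim = x.size(1)
    mean = torch.zeros(num_graphs, dim).index_reduce_(
        0, batch, x, "mean", include_self=False)
    mx = torch.zeros(num_graphs, dim).index_reduce_(
        0, batch, x, "amax", include_self=False)
    return torch.cat([mean, mx], dim=1)  # [num_graphs, 2 * dim]
```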

Time Feature Integration

When use_time_features=True:
  1. Time gaps projected to 32D embedding
  2. Concatenated with 512D graph features (256 × 2 from pooling)
  3. Fused through linear layer to produce final 512D embedding
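A minimal sketch of these three steps, assuming hypothetical module names time_proj and fuse (the real attribute names inside the class may differ):

```python
import torch
import torch.nn as nn

graph_dim, time_dim = 512, 32  # 512 = 256 * 2 from mean+max pooling

time_proj = nn.Linear(1, time_dim)                  # step 1: project gap to 32D
fuse = nn.Linear(graph_dim + time_dim, graph_dim)   # step 3: fuse back to 512D

graph_feats = torch.randn(4, graph_dim)  # pooled features for 4 graphs
time_to_predict = torch.rand(4, 1)       # time gaps, shape [batch_size, 1]

t = torch.relu(time_proj(time_to_predict))         # [4, 32]
fused = fuse(torch.cat([graph_feats, t], dim=1))   # step 2 + 3 -> [4, 512]
```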

Methods

load_state_dict_flexible

def load_state_dict_flexible(self, state_dict, strict=False)
Flexible loading of pretrained weights with support for missing TopK pooling layers.
  • state_dict (dict): State dictionary to load
  • strict (bool, default: False): Whether to strictly enforce matching keys
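The behavior can be approximated with a shape-aware key filter. This is a sketch of the idea, not the method's actual implementation:

```python
import torch.nn as nn

def load_state_dict_flexible_sketch(model, state_dict):
    """Load only the keys that exist in `model` with matching shapes.

    Sketch: checkpoints saved without TopK pooling layers can still be
    loaded into a model that has them (and vice versa).
    """
    own = model.state_dict()
    filtered = {k: v for k, v in state_dict.items()
                if k in own and v.shape == own[k].shape}
    own.update(filtered)
    model.load_state_dict(own)  # always a strict match after merging
    return set(state_dict) - set(filtered)  # keys that were skipped
```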

Notes

  • The encoder outputs 512D embeddings (256 × 2 from mean+max pooling)
  • TopK pooling ratio is clamped to [0.3, 1.0] for stability
  • Graph normalization applied after each convolutional layer
  • Supports flexible depth; 2-5 layers are typical
