Overview
GraphNeuralNetwork is a flexible GNN encoder that processes individual brain functional connectivity (FC) graphs. It supports multiple layer types (GCN, GAT, GraphSAGE), optional TopK pooling, and time-aware features.
Class Signature
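The exact signature is not reproduced on this page; the sketch below is a hedged reconstruction from the parameter list that follows. Argument names for the dropout, pooling ratio, layer type, depth, and activation parameters are assumptions, as are all defaults.

```python
# Hedged reconstruction of the constructor; names marked "assumed" and all
# defaults are illustrative, not confirmed by the source.
class GraphNeuralNetwork:
    def __init__(
        self,
        input_dim: int,                  # number of input node features
        hidden_dim: int = 128,           # hidden dim for intermediate layers (assumed default)
        output_dim: int = 256,           # output dim before pooling (doc implies 256)
        dropout: float = 0.5,            # assumed name and default
        use_topk_pooling: bool = False,  # hierarchical TopK vs. global pooling
        topk_ratio: float = 0.5,         # assumed name; clamped to [0.3, 1.0]
        conv_type: str = "GCN",          # assumed name; "GCN" | "GAT" | "GraphSAGE"
        num_layers: int = 3,             # assumed name; 2-5 layers typical
        activation: str = "relu",        # "relu" | "leaky_relu" | "elu" | "gelu"
        use_time_features: bool = False, # temporal gap features for time-aware prediction
    ):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_time_features = use_time_features
```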
Parameters
- Number of input node features (e.g., brain regions)
- Hidden dimension for intermediate GNN layers
- Output dimension for final layer (before pooling)
- Dropout probability for regularization
- Whether to use hierarchical TopK pooling instead of global pooling
- Ratio of nodes to keep in TopK pooling (minimum 30% enforced)
- Type of GNN layer: `"GCN"`, `"GAT"`, or `"GraphSAGE"`
- Number of GNN convolutional layers
- Activation function: `"relu"`, `"leaky_relu"`, `"elu"`, or `"gelu"`
- Whether to incorporate temporal gap features for time-aware prediction
Forward Method
Parameters
- Node features of shape `[num_nodes, input_dim]`
- Edge indices of shape `[2, num_edges]`
- Batch assignment vector for nodes
- Optional time gaps for prediction, of shape `[batch_size]` or `[batch_size, 1]`; required if `use_time_features=True`
Returns
Graph embedding of shape `[batch_size, output_dim * 2]` (concatenated mean + max pooling)
Example Usage
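The original snippet from main.py is not reproduced here. As a hedged stand-in, this sketch shows the input tensors a forward call would take, following PyTorch Geometric conventions; the variable names and the commented call pattern are assumptions, not source code.

```python
import torch

# Hypothetical forward-call setup for two FC graphs of 90 regions each.
batch_size, nodes_per_graph = 2, 90
input_dim = nodes_per_graph                                # FC rows as node features
x = torch.randn(batch_size * nodes_per_graph, input_dim)   # node features
edge_index = torch.randint(0, x.size(0), (2, 500))         # COO edge list
batch = torch.arange(batch_size).repeat_interleave(nodes_per_graph)
time_gaps = torch.rand(batch_size, 1)                      # needed if use_time_features=True

# encoder = GraphNeuralNetwork(input_dim=input_dim, output_dim=256)
# embedding = encoder(x, edge_index, batch, time_gaps)  # -> [2, 512]
```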
See main.py:177 for the usage in context.
Architecture Details
Layer Configuration
- First layer: `input_dim → hidden_dim`
- Middle layers: `hidden_dim → hidden_dim`
- Final layer: `hidden_dim → output_dim`
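The per-layer dimensions above can be enumerated with a small helper (hypothetical, not from the source):

```python
def layer_dims(input_dim: int, hidden_dim: int, output_dim: int, num_layers: int):
    """Return (in, out) dims per conv layer: the first maps input_dim -> hidden_dim,
    middle layers stay at hidden_dim, the last maps hidden_dim -> output_dim."""
    return [
        (input_dim if i == 0 else hidden_dim,
         output_dim if i == num_layers - 1 else hidden_dim)
        for i in range(num_layers)
    ]

layer_dims(90, 128, 256, 3)
# -> [(90, 128), (128, 128), (128, 256)]
```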
Pooling Strategy
With `use_topk_pooling=True`:
- TopK pooling applied after each layer
- Safe ratio enforced (minimum 30% of nodes retained)
- Final output: concatenated global mean + max pooling
With `use_topk_pooling=False`:
- Traditional global mean + max pooling on final layer only
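Both paths end in the same concatenated mean + max readout. A plain-PyTorch sketch of that readout is below; in practice PyTorch Geometric's `global_mean_pool`/`global_max_pool` would typically be used, and the helper name here is an assumption.

```python
import torch

def mean_max_pool(x: torch.Tensor, batch: torch.Tensor, num_graphs: int):
    # Concatenated global mean + max pooling over nodes, grouped by graph id.
    feat = x.size(1)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    mean = torch.zeros(num_graphs, feat).index_add_(0, batch, x) / counts.unsqueeze(1)
    mx = torch.full((num_graphs, feat), float("-inf")).scatter_reduce_(
        0, batch.unsqueeze(1).expand_as(x), x, reduce="amax")
    return torch.cat([mean, mx], dim=1)  # [num_graphs, 2 * feat]
```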
Time Feature Integration
When `use_time_features=True`:
- Time gaps projected to 32D embedding
- Concatenated with 512D graph features (256 × 2 from pooling)
- Fused through linear layer to produce final 512D embedding
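A hedged sketch of the fusion path described above; the layer names and the one-dimensional gap input are assumptions, but the dimensions (32D time embedding, 512D graph features) follow the doc.

```python
import torch
import torch.nn as nn

graph_dim, time_dim = 512, 32             # 512 = 256 * 2 from mean+max pooling
time_proj = nn.Linear(1, time_dim)        # project scalar gap to a 32D embedding
fuse = nn.Linear(graph_dim + time_dim, graph_dim)  # fuse back to 512D

graph_feats = torch.randn(4, graph_dim)   # pooled graph embeddings for a batch of 4
time_gaps = torch.rand(4, 1)              # one time gap per graph
fused = fuse(torch.cat([graph_feats, time_proj(time_gaps)], dim=1))  # [4, 512]
```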
Methods
load_state_dict_flexible
Parameters
- State dictionary to load
- Whether to strictly enforce matching keys
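A common implementation of such a method filters the incoming state dict to entries whose keys and tensor shapes match the model before loading. This is a hedged sketch under that assumption, not the source code:

```python
import torch.nn as nn

def load_state_dict_flexible(model: nn.Module, state_dict: dict, strict: bool = False):
    # Keep only entries whose key exists in the model with the same tensor
    # shape, then defer to the regular loader (non-strict by default).
    own = model.state_dict()
    filtered = {k: v for k, v in state_dict.items()
                if k in own and v.shape == own[k].shape}
    return model.load_state_dict(filtered, strict=strict)
```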
Notes
- The encoder outputs 512D embeddings (256 × 2 from mean+max pooling)
- TopK pooling ratio is clamped to [0.3, 1.0] for stability
- Graph normalization applied after each convolutional layer
- Supports flexible depth; 2-5 convolutional layers are typical