
Overview

DynamicGraphNeuralNetwork extends static GNN architectures to handle temporal sequences of graphs. It includes built-in temporal aggregation methods (mean, max, GRU) and end-to-end classification.

Class Signature

class DynamicGraphNeuralNetwork(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim=128,
        output_dim=256,
        num_classes=3,
        dropout=0.5,
        use_topk_pooling=True,
        topk_ratio=0.5,
        layer_type="GCN",
        temporal_aggregation="mean",
        num_layers=3,
        activation="relu"
    )

Parameters

input_dim (int, required): Raw node feature dimensionality (e.g., number of brain regions)
hidden_dim (int, default 128): Hidden dimension for GNN layers after input projection
output_dim (int, default 256): Output dimension of final GNN layer (before pooling)
num_classes (int, default 3): Number of classification classes
dropout (float, default 0.5): Dropout probability
use_topk_pooling (bool, default True): Whether to use TopK pooling layers
topk_ratio (float, default 0.5): Ratio of nodes to retain in TopK pooling
layer_type (str, default "GCN"): GNN layer type: "GCN", "GAT", or "GraphSAGE"
temporal_aggregation (str, default "mean"): Temporal aggregation method: "mean", "max", or "gru"
num_layers (int, default 3): Number of GNN layers
activation (str, default "relu"): Activation function: "relu", "leaky_relu", "elu", or "gelu"

Forward Method (Single Graph)

def forward(self, x, edge_index, batch, time_features=None)
Process a single graph and return embeddings. Compatible with TemporalDataLoader.

Parameters

x (torch.Tensor): Node features of shape [num_nodes, input_dim]
edge_index (torch.Tensor): Edge indices of shape [2, num_edges]
batch (torch.Tensor): Batch assignment for nodes
time_features (torch.Tensor, default None): Optional time features (not used in this encoder)

Returns

embeddings (torch.Tensor): Graph embeddings of shape [batch_size, 2 * output_dim]
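A minimal sketch of how the forward() inputs fit together, using plain PyTorch. The dimensions, edge lists, and input_dim value are illustrative only:

```python
import torch

# Two graphs batched together: graph 0 has 3 nodes, graph 1 has 2 nodes.
# input_dim = 4 is an arbitrary illustration value.
x = torch.randn(5, 4)                      # [num_nodes, input_dim]

# Edges are indexed over the batched node set (nodes 0-2 belong to
# graph 0, nodes 3-4 to graph 1).
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 0, 4]])  # [2, num_edges]

# batch[i] is the index of the graph that node i belongs to.
batch = torch.tensor([0, 0, 0, 1, 1])
```

Calling `model(x, edge_index, batch)` on inputs shaped like these would, per the contract above, return embeddings of shape [2, 2 * output_dim].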

Forward Sequence Method

def forward_sequence(self, x_seq, edge_index_seq, batch_seq)
Process a sequence of graphs for classification.

Parameters

x_seq (list[torch.Tensor]): List of node feature tensors, each of shape [num_nodes_t, input_dim]
edge_index_seq (list[torch.Tensor]): List of edge index tensors per time step
batch_seq (list[torch.Tensor]): List of batch tensors per time step

Returns

logits (torch.Tensor): Classification logits of shape [batch_size, num_classes]
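The sequence inputs are plain Python lists with one entry per time step. A sketch of building them for a single-graph batch (all sizes are illustrative, and the triangle graph is an arbitrary example):

```python
import torch

T = 3            # number of time steps (illustrative)
input_dim = 4    # illustrative feature dimension

# One 3-node graph per time step; node counts may vary across steps.
x_seq = [torch.randn(3, input_dim) for _ in range(T)]

# A simple triangle graph at every step (edges may also vary over time).
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_index_seq = [edge_index for _ in range(T)]

# All three nodes belong to graph 0 at every step (batch_size = 1).
batch_seq = [torch.zeros(3, dtype=torch.long) for _ in range(T)]
```

Passing these lists to `model.forward_sequence(x_seq, edge_index_seq, batch_seq)` would, per the contract above, yield logits of shape [1, num_classes].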

Architecture Details

Spatial Processing

  1. Input projection: input_dim → hidden_dim
  2. GNN layers with GraphNorm
  3. Optional TopK pooling after each layer
  4. Global mean + max pooling, concatenated → 2 * output_dim
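Step 4 can be sketched in plain PyTorch. This is a stand-in for the concatenated global mean/max pooling described above, not the library's actual code; the helper name and the use of index_reduce_ are illustrative:

```python
import torch

def global_mean_max_pool(h, batch, num_graphs):
    """Concatenated global mean + max pooling over a batched node set.

    h:      [num_nodes, output_dim] node embeddings from the last GNN layer
    batch:  [num_nodes] graph index of each node
    Returns [num_graphs, 2 * output_dim].
    """
    dim = h.size(1)
    mean = torch.zeros(num_graphs, dim).index_reduce_(
        0, batch, h, "mean", include_self=False)
    mx = torch.zeros(num_graphs, dim).index_reduce_(
        0, batch, h, "amax", include_self=False)
    return torch.cat([mean, mx], dim=1)

h = torch.randn(5, 8)                  # 5 nodes, output_dim = 8
batch = torch.tensor([0, 0, 0, 1, 1])  # two graphs in the batch
out = global_mean_max_pool(h, batch, num_graphs=2)
print(out.shape)                       # [2, 16] == [num_graphs, 2 * output_dim]
```

In practice this corresponds to concatenating PyTorch Geometric's `global_mean_pool` and `global_max_pool` outputs.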

Temporal Aggregation

Per-time-step embeddings from forward() are stacked into time_outputs of shape [batch_size, T, 2 * output_dim], then reduced according to temporal_aggregation:

Mean aggregation: Average embeddings across time steps
out = time_outputs.mean(dim=1)
Max aggregation: Max-pool embeddings across time steps
out, _ = time_outputs.max(dim=1)
GRU aggregation: Process the sequence through a GRU and keep the final hidden state
_, h = self.gru(time_outputs)
out = h.squeeze(0)
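The three reductions above can be exercised on a dummy embedding tensor; all yield the same output shape. The batch_first=True GRU layout and the toy embedding width are assumptions for illustration:

```python
import torch
import torch.nn as nn

batch_size, T, emb_dim = 2, 4, 16   # emb_dim plays the role of 2 * output_dim

# Stacked per-time-step graph embeddings: [batch_size, T, 2 * output_dim]
time_outputs = torch.randn(batch_size, T, emb_dim)

# Mean and max aggregation reduce the time dimension directly.
mean_out = time_outputs.mean(dim=1)
max_out, _ = time_outputs.max(dim=1)

# GRU aggregation: the hidden size matches the embedding width, so the
# final hidden state has the same shape as the other two reductions.
gru = nn.GRU(emb_dim, emb_dim, batch_first=True)
_, h = gru(time_outputs)                # h: [1, batch_size, emb_dim]
gru_out = h.squeeze(0)                  # [batch_size, emb_dim]

print(mean_out.shape, max_out.shape, gru_out.shape)  # all [2, 16]
```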

Classification

Final linear layer: temporal_dim → num_classes, where temporal_dim equals 2 * output_dim for all three aggregation methods (mean and max preserve the embedding width, and the GRU hidden size is set to match)

Methods

load_state_dict_flexible

def load_state_dict_flexible(self, state_dict)
Load state dict with flexibility for missing or extra keys.
state_dict (dict): State dictionary to load
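The actual implementation is not shown here; one common pattern for this kind of flexible loading, sketched below under that assumption, keeps only the checkpoint entries whose names and shapes match the model and loads them with strict=False:

```python
import torch
import torch.nn as nn

def load_state_dict_flexible(model, state_dict):
    """Load only checkpoint keys that exist in the model with matching shapes.

    A sketch of the likely behavior; the real method may differ.
    """
    own = model.state_dict()
    compatible = {k: v for k, v in state_dict.items()
                  if k in own and own[k].shape == v.shape}
    missing, unexpected = model.load_state_dict(compatible, strict=False)
    return missing, unexpected

# Checkpoint with one matching key, one shape mismatch, one extra key.
model = nn.Linear(4, 2)
ckpt = {
    "weight": torch.ones(2, 4),     # matches -> loaded
    "bias": torch.ones(3),          # wrong shape -> skipped
    "extra.weight": torch.ones(1),  # not in model -> skipped
}
missing, unexpected = load_state_dict_flexible(model, ckpt)
```

With strict=False, `load_state_dict` reports skipped keys instead of raising, which is what makes tolerating missing or extra entries possible.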

Notes

  • The forward() method processes single graphs, while forward_sequence() handles temporal sequences
  • The TopK pooling ratio is clamped to a minimum of 0.3 for stability
  • When using GRU aggregation, hidden size equals 2 * output_dim
  • Compatible with both single-graph and multi-graph temporal workflows
