
Overview

Pooling operations in the GraphNeuralNetwork encoder reduce node-level representations to fixed-size graph-level embeddings. The STGNN architecture supports two pooling strategies:
  1. TopK Pooling: Hierarchical graph coarsening that selectively retains the most important nodes
  2. Global Pooling: Direct aggregation of all node features using mean and max operations
Both strategies produce the same output dimension (output_dim × 2) but differ in computational approach and learned representations.

Global Pooling

Overview

Global pooling aggregates information from all nodes in the graph using permutation-invariant operations:
from model import GraphNeuralNetwork

encoder = GraphNeuralNetwork(
    input_dim=100,
    hidden_dim=128,
    output_dim=256,
    use_topk_pooling=False  # Disable TopK pooling
)

Implementation

The encoder combines two global pooling operations:
import torch
from torch_geometric.nn import global_mean_pool, global_max_pool

# After graph convolution layers
x_mean = global_mean_pool(x, batch)  # [B, output_dim]
x_max = global_max_pool(x, batch)    # [B, output_dim]

# Concatenate for richer representation
x = torch.cat([x_mean, x_max], dim=1)  # [B, output_dim * 2]

Mean Pooling

Averages node features across the entire graph:
x_mean = global_mean_pool(x, batch)
# Equivalent to:
# x_mean[i] = mean(x[batch == i], dim=0)
Properties:
  • Captures average node properties
  • Robust to outliers
  • Smooth representations
  • Equal weight to all nodes

Max Pooling

Takes element-wise maximum across node features:
x_max = global_max_pool(x, batch)
# Equivalent to:
# x_max[i] = max(x[batch == i], dim=0)
Properties:
  • Captures salient features
  • Emphasizes extreme values
  • Sparse activation patterns
  • Focuses on discriminative nodes

Concatenation Strategy

Combining mean and max pooling provides complementary information:
# Mean: captures overall graph structure
# Max: captures distinctive node properties
x = torch.cat([x_mean, x_max], dim=1)
Advantages:
  • Richer representations: Captures both average and extreme properties
  • Improved performance: Combining both often outperforms either pooling method alone
  • Minimal overhead: No additional parameters, just concatenation
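The combined operation is easy to verify by hand. The following dependency-free NumPy sketch (the function name `global_mean_max_concat` is made up for illustration) mirrors what `global_mean_pool`, `global_max_pool`, and `torch.cat` compute for a batched graph:

```python
import numpy as np

def global_mean_max_concat(x, batch, num_graphs):
    """Sketch of mean + max global pooling followed by concatenation.

    x:     [num_nodes, feat_dim] node features
    batch: [num_nodes] graph index of each node
    """
    feat_dim = x.shape[1]
    out = np.zeros((num_graphs, feat_dim * 2))
    for g in range(num_graphs):
        nodes = x[batch == g]                   # nodes belonging to graph g
        out[g, :feat_dim] = nodes.mean(axis=0)  # mean pooling half
        out[g, feat_dim:] = nodes.max(axis=0)   # max pooling half
    return out

# Two graphs in one batch: graph 0 owns nodes 0-1, graph 1 owns node 2
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = np.array([0, 0, 1])
pooled = global_mean_max_concat(x, batch, num_graphs=2)
print(pooled)  # row 0: [2. 3. 3. 4.], row 1: [5. 6. 5. 6.]
```

Each output row has width `feat_dim * 2`, matching the `[B, output_dim * 2]` shape produced by the encoder.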

Advantages of Global Pooling

  1. No additional parameters: Uses all nodes without learned selection
  2. Computationally efficient: Simple aggregation operations
  3. Stable training: No potential for empty graphs
  4. Interpretable: Clear semantic meaning (average and maximum)
  5. Memory efficient: No intermediate graph storage

Disadvantages

  1. Fixed receptive field: Cannot focus on important subgraphs
  2. Noise sensitivity: Includes all nodes, even noisy ones
  3. Limited hierarchical structure: Flat aggregation, no coarsening
  4. Uniform weighting: All nodes contribute equally (in mean pooling)

TopK Pooling

Overview

TopK pooling hierarchically coarsens graphs by iteratively selecting the most important nodes:
from model import GraphNeuralNetwork

encoder = GraphNeuralNetwork(
    input_dim=100,
    hidden_dim=128,
    output_dim=256,
    use_topk_pooling=True,  # Enable TopK pooling
    topk_ratio=0.5,         # Keep 50% of nodes at each layer
    num_layers=3
)

Implementation

TopK pooling is applied after each graph convolution layer:
from torch_geometric.nn import TopKPooling

for i in range(num_layers):
    # Graph convolution
    x = self.convs[i](x, edge_index)
    x = self.gns[i](x, batch)
    x = self.activation_fn(x)
    x = self.dropout(x)
    
    # TopK pooling: select top nodes based on learned scores
    x, edge_index, _, batch, perm, score = self.topk_pools[i](
        x, edge_index, batch=batch
    )

# Final global pooling on selected nodes
x_mean = global_mean_pool(x, batch)
x_max = global_max_pool(x, batch)
x = torch.cat([x_mean, x_max], dim=1)

How TopK Pooling Works

  1. Score Computation: Each node receives a score based on its features
    score = torch.tanh(linear_projection(x))  # [num_nodes, 1]
    
  2. Node Selection: Keep top k nodes with highest scores
    k = max(1, int(topk_ratio * num_nodes))
    _, perm = score.topk(k, dim=0)
    
  3. Graph Coarsening: Retain selected nodes and their connecting edges
    x = x[perm] * score[perm]  # Weight features by importance
    edge_index, _ = filter_edges(edge_index, perm)
    
  4. Batch Update: Update batch assignments for remaining nodes
    batch = batch[perm]
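The four steps can be collapsed into a small NumPy sketch (a hypothetical `topk_pool_sketch`; edge filtering is omitted for brevity, and the projection vector `p` stands in for the learned weights):

```python
import numpy as np

def topk_pool_sketch(x, batch, p, ratio):
    """Minimal sketch of TopK pooling; edge filtering omitted.

    x:     [num_nodes, feat_dim] node features
    batch: [num_nodes] graph index of each node
    p:     [feat_dim] learned projection vector
    ratio: fraction of nodes to keep
    """
    # 1. Score computation: project onto p, squash with tanh
    score = np.tanh(x @ p / np.linalg.norm(p))      # [num_nodes]
    # 2. Node selection: keep the top-k scoring nodes
    k = max(1, int(ratio * x.shape[0]))
    perm = np.argsort(-score)[:k]
    # 3. Coarsening: weight retained features by their scores
    x_out = x[perm] * score[perm, None]
    # 4. Batch update: keep assignments of surviving nodes
    return x_out, batch[perm], perm

x = np.array([[1.0, 0.0], [0.0, 3.0], [2.0, 2.0], [0.5, 0.5]])
batch = np.zeros(4, dtype=int)
p = np.array([1.0, 1.0])
x_out, batch_out, perm = topk_pool_sketch(x, batch, p, ratio=0.5)
print(perm)  # [2 1] -- the two highest-scoring nodes
```

Multiplying the retained features by their scores (step 3) is what lets gradients flow back into the scoring projection despite the discrete selection.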
    

Safe Ratio Enforcement

The implementation includes safeguards to prevent empty graphs:
# From model.py lines 74-87
safe_ratio = max(0.3, min(1.0, topk_ratio))
self.topk_pools.append(TopKPooling(pool_dim, ratio=safe_ratio, min_score=None))

print(f"TopK pooling initialized with {num_layers} layers, "
      f"safe ratio: {safe_ratio} (minimum 30%)")
Rationale:
  • Minimum 30% retention ensures at least some nodes survive deep architectures
  • Prevents empty graphs that would cause forward pass failures
  • Balances selectivity with stability
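A standalone version of the clamp (hypothetical helper name) shows its behavior at the boundaries:

```python
def clamp_topk_ratio(topk_ratio):
    """Clamp the requested ratio into [0.3, 1.0], as done at encoder init."""
    return max(0.3, min(1.0, topk_ratio))

print(clamp_topk_ratio(0.1))  # 0.3  (too aggressive, raised to the floor)
print(clamp_topk_ratio(0.5))  # 0.5  (within range, unchanged)
print(clamp_topk_ratio(1.5))  # 1.0  (capped at keeping all nodes)
```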

Layer-wise Pooling

TopK pooling creates a hierarchy of progressively coarsened graphs:
# Example with 3 layers, 100 initial nodes, ratio=0.5 (within the safe range, no clamping)
Layer 0: 100 nodes → 50 nodes (50% retained)
Layer 1: 50 nodes → 25 nodes (50% retained)
Layer 2: 25 nodes → 12 nodes (50% retained, rounded down)
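The counts above follow from repeatedly applying the truncating rule `k = max(1, int(ratio * n))` shown earlier; a hypothetical helper makes this easy to check:

```python
def node_counts(num_nodes, ratio, num_layers):
    """Node count after each TopK layer, using k = max(1, int(ratio * n))."""
    counts = [num_nodes]
    for _ in range(num_layers):
        counts.append(max(1, int(ratio * counts[-1])))
    return counts

print(node_counts(100, 0.5, 3))  # [100, 50, 25, 12]
```

The `max(1, ...)` floor guarantees at least one node survives even for very small graphs.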

Advantages of TopK Pooling

  1. Hierarchical structure: Multi-scale graph representations
  2. Learned selection: Adapts to identify important nodes for the task
  3. Noise reduction: Filters out less relevant nodes
  4. Improved expressiveness: Can focus on discriminative subgraphs
  5. Attention-like mechanism: Weights nodes by importance

Disadvantages

  1. Additional parameters: Each TopK layer has learnable projection weights
  2. Computational overhead: Score computation and graph filtering
  3. Training instability: Risk of empty graphs without safeguards
  4. Discrete operation: Node selection is non-differentiable; gradients reach the scoring weights only through the score multiplication
  5. Memory overhead: Stores intermediate graphs

Comparison

Performance

| Metric             | Global Pooling | TopK Pooling             |
| ------------------ | -------------- | ------------------------ |
| Parameters         | 0 additional   | num_layers × output_dim  |
| Computation        | O(N)           | O(N log N) per layer     |
| Memory             | Low            | Medium                   |
| Training stability | High           | Medium (with safeguards) |
| Expressiveness     | Good           | Excellent                |

When to Use Global Pooling

  • Small graphs: When nodes < 50, selection overhead dominates
  • Limited data: Fewer parameters reduce overfitting risk
  • Uniform importance: When all nodes contribute equally
  • Baseline models: Quick prototyping and experimentation
  • Interpretability: Simple, well-understood aggregation

When to Use TopK Pooling

  • Large graphs: When nodes > 100, can focus on important regions
  • Noisy data: Filter out irrelevant or low-quality nodes
  • Hierarchical structure: Exploit multi-scale patterns
  • Performance critical: When model capacity is more important than speed
  • Sufficient data: Large datasets can support additional parameters

Configuration Examples

Conservative TopK (High Retention)

encoder = GraphNeuralNetwork(
    input_dim=100,
    hidden_dim=128,
    output_dim=256,
    use_topk_pooling=True,
    topk_ratio=0.8,  # Keep 80% of nodes
    num_layers=3
)
# Effective retention: 0.8 × 0.8 × 0.8 = 51.2% of original nodes
Use case: Preserve most information, gentle coarsening

Aggressive TopK (Low Retention)

encoder = GraphNeuralNetwork(
    input_dim=100,
    hidden_dim=128,
    output_dim=256,
    use_topk_pooling=True,
    topk_ratio=0.3,  # Keep 30% of nodes (minimum allowed)
    num_layers=3
)
# Effective retention: 0.3 × 0.3 × 0.3 = 2.7% of original nodes
Use case: Strong sparsification, focus on most salient nodes

Balanced Configuration

encoder = GraphNeuralNetwork(
    input_dim=100,
    hidden_dim=128,
    output_dim=256,
    use_topk_pooling=True,
    topk_ratio=0.5,  # Keep 50% of nodes
    num_layers=3
)
# Effective retention: 0.5 × 0.5 × 0.5 = 12.5% of original nodes
Use case: Standard configuration, good balance of selectivity and stability
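Note that the `ratio ** num_layers` figures quoted in the comments are nominal; integer truncation at each layer makes the realized fraction slightly lower. A quick check with hypothetical helpers:

```python
def nominal_retention(ratio, num_layers):
    """Idealized fraction of nodes surviving all TopK layers."""
    return ratio ** num_layers

def actual_retention(num_nodes, ratio, num_layers):
    """Realized fraction, accounting for k = max(1, int(ratio * n)) truncation."""
    n = num_nodes
    for _ in range(num_layers):
        n = max(1, int(ratio * n))
    return n / num_nodes

print(nominal_retention(0.5, 3))      # 0.125
print(actual_retention(100, 0.5, 3))  # 0.12 (12 of 100 nodes survive)
```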

Global Pooling (No Coarsening)

encoder = GraphNeuralNetwork(
    input_dim=100,
    hidden_dim=128,
    output_dim=256,
    use_topk_pooling=False,  # Use global pooling only
    num_layers=3
)
# All nodes contribute to final representation
Use case: Simple baseline, maximum stability

Implementation Details

Forward Pass with TopK

From model.py lines 136-155:
def _forward_with_topk(self, x, edge_index, batch):
    # Dynamic forward pass through all layers with TopK pooling
    for i in range(self.num_layers):
        # Apply convolution
        x = self.convs[i](x, edge_index)
        x = self.gns[i](x, batch)
        x = self.activation_fn(x)
        x = self.dropout(x)
        
        # Apply TopK pooling after each layer
        x, edge_index, _, batch, perm, score = self.topk_pools[i](
            x, edge_index, batch=batch
        )
    
    # Global pooling on the remaining top nodes
    x_mean = self.pool_mean(x, batch)
    x_max = self.pool_max(x, batch)
    
    # Combine pooled representations
    x = torch.cat([x_mean, x_max], dim=1)
    
    return x

Forward Pass without TopK

From model.py lines 157-171:
def _forward_traditional(self, x, edge_index, batch):
    # Dynamic forward pass through all layers without TopK pooling
    for i in range(self.num_layers):
        # Apply convolution
        x = self.convs[i](x, edge_index)
        x = self.gns[i](x, batch)
        x = self.activation_fn(x)
        x = self.dropout(x)
    
    # Combine mean and max pooling for richer representation
    x_mean = self.pool_mean(x, batch)
    x_max = self.pool_max(x, batch)
    x = torch.cat([x_mean, x_max], dim=1)
    
    return x

Experimental Recommendations

# Grid search over pooling strategies
configs = [
    {"use_topk_pooling": False},  # Baseline
    {"use_topk_pooling": True, "topk_ratio": 0.3},
    {"use_topk_pooling": True, "topk_ratio": 0.5},
    {"use_topk_pooling": True, "topk_ratio": 0.7},
]

for config in configs:
    model = GraphNeuralNetwork(
        input_dim=100,
        hidden_dim=128,
        output_dim=256,
        num_layers=3,
        **config
    )
    # Train and evaluate...

Ablation Study

# Test individual pooling operations (assign to a new name so the
# node features x are not overwritten between runs)
for pool_type in ['mean', 'max', 'concat']:
    if pool_type == 'mean':
        pooled = global_mean_pool(x, batch)
    elif pool_type == 'max':
        pooled = global_max_pool(x, batch)
    else:  # concat
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        pooled = torch.cat([x_mean, x_max], dim=1)
    # Train and evaluate with this pooled representation...

Visualization

Node Retention Across Layers

import matplotlib.pyplot as plt
import numpy as np

def plot_node_retention(num_nodes=100, topk_ratio=0.5, num_layers=3):
    nodes = [num_nodes]
    for _ in range(num_layers):
        safe_ratio = max(0.3, min(1.0, topk_ratio))  # same clamp as encoder init
        nodes.append(int(nodes[-1] * safe_ratio))
    
    plt.plot(range(len(nodes)), nodes, marker='o')
    plt.xlabel('Layer')
    plt.ylabel('Number of Nodes')
    plt.title(f'TopK Pooling: Node Retention (ratio={topk_ratio})')
    plt.grid(True)
    plt.show()

plot_node_retention()
