
Overview

STGNN supports two pretraining strategies for the GNN encoder:
  1. Supervised pretraining (recommended) - Uses graph classification with labels
  2. Self-supervised pretraining (GraphCL-style) - Uses contrastive learning without labels
Pretraining helps the encoder learn better graph representations before temporal sequence modeling. Supervised pretraining trains the GNN encoder on graph-level classification, which is more stable than contrastive methods and avoids representation collapse.

Running Supervised Pretraining

python supervised_pretrain.py
This will:
  1. Load the FC graph dataset
  2. Split into train/validation (80/20 stratified split)
  3. Train a GNN encoder with a classification head for 100 epochs
  4. Save the pretrained encoder to ./model/pretrained_gnn_encoder.pth
Implementation: supervised_pretrain.py

Architecture

The supervised pretraining uses:
encoder = GraphNeuralNetwork(
    input_dim=100,           # Node feature dimension
    hidden_dim=256,          # Hidden layer size
    output_dim=256,          # Embedding dimension
    dropout=0.2,
    use_topk_pooling=True,   # Hierarchical pooling
    topk_ratio=0.3,          # Keep 30% of nodes
    layer_type="GraphSAGE",  # GNN architecture
    num_layers=2,            # Number of GNN layers
    activation="elu",        # Activation function
    use_time_features=False  # No temporal features
)

# Classification head for pretraining
classifier = nn.Sequential(
    nn.Linear(512, 128),     # 512 from TopK pooling (256*2)
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(128, 2)        # Binary classification
)
See supervised_pretrain.py:157-177 for configuration.
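The classifier's 512-d input comes from the pooled graph embedding. A minimal pure-torch sketch (not the repo's code) of the assumed readout: if the encoder concatenates global mean- and max-pooling of 256-d node embeddings, the graph embedding is 256*2 = 512 dimensions, matching the head's first Linear layer.

```python
import torch

# Hypothetical illustration: 30 retained nodes with 256-d embeddings
node_embeddings = torch.randn(30, 256)

graph_embedding = torch.cat([
    node_embeddings.mean(dim=0),        # global mean pool -> 256
    node_embeddings.max(dim=0).values,  # global max pool  -> 256
])

print(graph_embedding.shape)  # torch.Size([512])
```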

Training Configuration

epochs = 100
batch_size = 32
learning_rate = 1e-3
weight_decay = 1e-5

optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(classifier.parameters()),
    lr=learning_rate,
    weight_decay=weight_decay
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=5,
    min_lr=1e-5
)
Implementation: supervised_pretrain.py:56-62
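The scheduler watches validation accuracy (mode='max'): when it stops improving for `patience` epochs, the learning rate is halved. A minimal sketch with a dummy parameter, simulating a stalled metric:

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-5)

# Feed a constant validation accuracy: a plateau, so after the
# patience window the learning rate is cut in half
for epoch in range(10):
    scheduler.step(0.70)

print(optimizer.param_groups[0]['lr'])  # reduced below 1e-3
```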

Training Process

def pretrain_encoder_supervised(encoder, dataset, device, 
                                epochs=50, batch_size=32, lr=1e-3):
    # Split dataset stratified by labels
    train_idx, val_idx = train_test_split(
        indices, test_size=0.2, stratify=labels, random_state=42
    )
    
    for epoch in range(epochs):
        encoder.train()
        classifier.train()
        
        for batch in train_loader:
            # Forward pass through encoder
            embeddings = encoder(batch.x, batch.edge_index, batch.batch)
            logits = classifier(embeddings)
            loss = F.cross_entropy(logits, batch.y)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                list(encoder.parameters()) + list(classifier.parameters()),
                max_norm=1.0
            )
            optimizer.step()
        
        # Validation
        val_acc = evaluate(val_loader)
        scheduler.step(val_acc)
        
        # Save best model
        if val_acc > best_val_acc:
            best_encoder_state = encoder.state_dict().copy()
Implementation: supervised_pretrain.py:33-138

Expected Performance

Typical supervised pretraining results:
  • Training accuracy: 70-85%
  • Validation accuracy: 65-80%
  • Converges in 30-50 epochs
  • Best model selected by validation accuracy

Saved Model

The pretrained encoder is saved to:
./model/pretrained_gnn_encoder.pth
File size: ~500-800 KB (depends on architecture). The saved state dict contains only encoder weights (the classifier head is discarded).
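Saving only the encoder, then restoring it into a freshly constructed model, can be sketched as follows (stand-in `nn.Linear` modules for illustration; the real encoder is GraphNeuralNetwork, and the save path is assumed):

```python
import os
import tempfile
import torch
import torch.nn as nn

# Stand-in modules for illustration only
encoder = nn.Linear(100, 256)
classifier = nn.Sequential(nn.Linear(512, 128), nn.ReLU(), nn.Linear(128, 2))

# Save only the encoder's state dict; the classifier head is discarded
save_dir = tempfile.mkdtemp()
path = os.path.join(save_dir, 'pretrained_gnn_encoder.pth')
torch.save(encoder.state_dict(), path)

# Later: restore into a freshly constructed encoder
restored = nn.Linear(100, 256)
restored.load_state_dict(torch.load(path))
```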

Self-Supervised Pretraining (GraphCL)

GraphCL-style contrastive pretraining creates augmented views of graphs and learns representations by maximizing agreement between augmentations.

Enable in Training

python dfc_main.py \
  --pretrain_encoder \
  --pretrain_epochs 50 \
  --use_pretrained  # Load existing if available
See dfc_main.py:40-43 for arguments.

Graph Augmentation

Two augmentation strategies are implemented:

1. Drop Node

def graph_augmentation(data, aug_type="drop_node", aug_ratio=0.2):
    num_nodes = data.x.size(0)
    node_mask = torch.rand(num_nodes, device=data.x.device) > aug_ratio

    # Keep nodes that pass the mask
    data.x = data.x[node_mask]

    # Remap node indices and filter edges
    new_idx = torch.full((num_nodes,), -1, dtype=torch.long,
                         device=data.x.device)
    new_idx[node_mask] = torch.arange(node_mask.sum(), device=data.x.device)

    edge_mask = node_mask[data.edge_index[0]] & node_mask[data.edge_index[1]]
    data.edge_index = new_idx[data.edge_index[:, edge_mask]]
    return data
  • Randomly drops 20% of nodes
  • Updates edge indices accordingly
  • Preserves graph structure
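The remapping step can be traced on a toy graph. A pure-torch sketch with an explicit (non-random) mask for determinism, assuming a 4-node path graph:

```python
import torch

# Toy 4-node path graph: 0-1-2-3 (directed edges for brevity)
edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 3]])
num_nodes = 4

# Deterministic mask for illustration: drop node 2
node_mask = torch.tensor([True, True, False, True])

# Remap surviving nodes (0, 1, 3) to a compact index range (0, 1, 2)
new_idx = torch.full((num_nodes,), -1, dtype=torch.long)
new_idx[node_mask] = torch.arange(node_mask.sum())

# Keep only edges whose endpoints both survive, then remap their indices
edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
edge_index = new_idx[edge_index[:, edge_mask]]

print(edge_index)  # only edge 0->1 survives
```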

2. Drop Edge

def graph_augmentation(data, aug_type="drop_edge", aug_ratio=0.2):
    edge_mask = torch.rand(data.edge_index.size(1)) > aug_ratio
    data.edge_index = data.edge_index[:, edge_mask]
    return data
  • Randomly drops 20% of edges
  • Simpler than node dropping
  • Creates sparser connectivity
Implementation: dfc_main.py:92-123

Contrastive Loss (NT-Xent)

def nt_xent_loss(z1, z2, temperature=0.5):
    batch_size = z1.size(0)

    # Normalize embeddings
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    # Concatenate augmented views
    reps = torch.cat([z1, z2], dim=0)

    # Compute similarity matrix
    sim_matrix = torch.exp(torch.matmul(reps, reps.T) / temperature)

    # Positive pairs (same graph, different augmentations), one per view
    pos_sim = torch.exp(torch.sum(z1 * z2, dim=1) / temperature)
    pos_sim = torch.cat([pos_sim, pos_sim], dim=0)

    # Negative pairs (different graphs): mask out self-similarity
    mask = ~torch.eye(2 * batch_size, dtype=torch.bool, device=reps.device)
    sim_sum = sim_matrix.masked_select(mask).view(2 * batch_size, -1).sum(dim=1)

    # Contrastive loss
    loss = -torch.log(pos_sim / sim_sum)
    return loss.mean()
  • Temperature parameter: 0.5 (controls softness of softmax)
  • Maximizes similarity between augmented views of same graph
  • Minimizes similarity between different graphs
Implementation: dfc_main.py:125-138
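As a quick sanity check, a self-contained reimplementation of the loss (same logic as above, condensed for the demo) should score identical views lower than unrelated random batches:

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    # Condensed NT-Xent for demonstration purposes
    batch_size = z1.size(0)
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    reps = torch.cat([z1, z2], dim=0)
    sim_matrix = torch.exp(reps @ reps.T / temperature)
    pos = torch.exp((z1 * z2).sum(dim=1) / temperature)
    pos = torch.cat([pos, pos], dim=0)
    mask = ~torch.eye(2 * batch_size, dtype=torch.bool)
    denom = sim_matrix.masked_select(mask).view(2 * batch_size, -1).sum(dim=1)
    return -torch.log(pos / denom).mean()

torch.manual_seed(0)
z = torch.randn(8, 16)

# Perfectly agreeing views should incur a lower loss than random pairs
loss_aligned = nt_xent_loss(z, z)
loss_random = nt_xent_loss(z, torch.randn(8, 16))
print(loss_aligned < loss_random)  # True
```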

Pretraining Loop

def pretrain_graph_encoder(encoder, dataset, device,
                           epochs=50, batch_size=32, lr=1e-3):
    encoder.train()
    optimizer = torch.optim.Adam(encoder.parameters(), lr=lr)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        for batch in loader:
            # Create two augmented views and re-collate them
            # (Batch from torch_geometric.data)
            batch1 = Batch.from_data_list(
                [graph_augmentation(d.clone(), "drop_node", 0.2)
                 for d in batch.to_data_list()]).to(device)
            batch2 = Batch.from_data_list(
                [graph_augmentation(d.clone(), "drop_edge", 0.2)
                 for d in batch.to_data_list()]).to(device)

            # Get embeddings for both views
            z1 = encoder(batch1.x, batch1.edge_index, batch1.batch)
            z2 = encoder(batch2.x, batch2.edge_index, batch2.batch)

            # Contrastive loss
            loss = nt_xent_loss(z1, z2)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
Implementation: dfc_main.py:141-179

Using Pretrained Encoders

Loading in Training Scripts

Both main.py and dfc_main.py automatically load pretrained encoders:
pretrained_path = os.path.join(opt.save_path, 'pretrained_gnn_encoder.pth')

if os.path.exists(pretrained_path):
    encoder.load_state_dict_flexible(torch.load(pretrained_path))
    print(f"Loaded pretrained encoder from {pretrained_path}")
else:
    print("No pretrained encoder found — training from scratch.")
See main.py:184-188 for static FC and dfc_main.py:267-283 for dynamic FC.

Freezing Encoder Weights

Optionally freeze the encoder during temporal training:
python main.py --freeze_encoder
if opt.freeze_encoder:
    for p in fold_encoder.parameters():
        p.requires_grad = False
    print("GNN encoder frozen for transfer learning.")
When frozen:
  • Only the temporal classifier (LSTM/GRU/RNN) is trained
  • Encoder serves as fixed feature extractor
  • Faster training, but may sacrifice adaptability
See main.py:230-232 and dfc_main.py:335-338.
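The effect of freezing can be verified by counting trainable parameters. A minimal sketch with stand-in modules (the real pairing is the GNN encoder plus the LSTM/GRU/RNN classifier):

```python
import torch.nn as nn

# Stand-ins for illustration: a "GNN encoder" and a temporal classifier
encoder = nn.Linear(100, 256)
temporal = nn.LSTM(256, 128, batch_first=True)

# Freeze the encoder: its parameters stop receiving gradients
for p in encoder.parameters():
    p.requires_grad = False

trainable = [p for m in (encoder, temporal)
             for p in m.parameters() if p.requires_grad]

# Only the temporal module's parameters remain trainable
print(len(trainable), len(list(temporal.parameters())))
```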

Flexible State Dict Loading

The load_state_dict_flexible() method handles architecture mismatches:
def load_state_dict_flexible(self, state_dict):
    model_dict = self.state_dict()
    # Filter keys that match shape
    state_dict = {
        k: v for k, v in state_dict.items() 
        if k in model_dict and model_dict[k].shape == v.shape
    }
    model_dict.update(state_dict)
    self.load_state_dict(model_dict, strict=False)
This allows:
  • Loading pretrained weights into models with different heads
  • Partial weight initialization
  • Architecture modifications after pretraining
Implementation: main.py:592 (encoder method)
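A standalone function version of the same idea, applied to two models that share a backbone but differ in their output head (stand-in `nn.Sequential` models for illustration):

```python
import torch
import torch.nn as nn

def load_state_dict_flexible(model, state_dict):
    # Keep only keys whose shapes match the target model,
    # then load non-strictly so unmatched keys are left untouched
    model_dict = model.state_dict()
    filtered = {k: v for k, v in state_dict.items()
                if k in model_dict and model_dict[k].shape == v.shape}
    model_dict.update(filtered)
    model.load_state_dict(model_dict, strict=False)
    return filtered.keys()

# Pretrained model and a target with a different output head
pretrained = nn.Sequential(nn.Linear(100, 256), nn.Linear(256, 2))
target = nn.Sequential(nn.Linear(100, 256), nn.Linear(256, 10))

loaded = load_state_dict_flexible(target, pretrained.state_dict())
print(sorted(loaded))  # only the matching first layer transfers
```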

Pretraining Comparison

| Aspect | Supervised | Self-Supervised (GraphCL) |
|--------|------------|---------------------------|
| Requires labels | Yes | No |
| Stability | High | Medium (can collapse) |
| Training time | Faster | Slower (2x augmentations) |
| Performance | Better for classification | Better for general features |
| Recommended for | STGNN (we have labels) | Unlabeled graph data |

Best Practices

  1. Use supervised pretraining for STGNN - it’s more stable and performs better
  2. Run pretraining once using supervised_pretrain.py, then load weights for all experiments
  3. Don’t freeze encoder unless you have very limited training data
  4. Use same architecture for pretraining and downstream tasks (hidden_dim, num_layers, etc.)
  5. Monitor validation accuracy during supervised pretraining - stop early if overfitting

Reproducibility

All pretraining methods set random seeds for reproducibility:
def set_random_seeds(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
Seeds are set before:
  • Model initialization
  • Data loading
  • Training loops
See supervised_pretrain.py:23-31 and dfc_main.py:58-66.
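Re-seeding replays the exact same random stream, which is what makes runs reproducible. A minimal demonstration, with the CUDA calls guarded so it also runs on CPU-only machines:

```python
import random
import numpy as np
import torch

def set_random_seeds(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Seeding twice with the same value yields identical random tensors
set_random_seeds(42)
a = torch.rand(3)
set_random_seeds(42)
b = torch.rand(3)
print(torch.equal(a, b))  # True
```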

Implementation Files

  • supervised_pretrain.py: Standalone supervised pretraining script
  • dfc_main.py:89-179: GraphCL self-supervised pretraining
  • main.py:184-188: Loading pretrained weights for static FC
  • dfc_main.py:267-283: Loading pretrained weights for DFC
  • model.py: GraphNeuralNetwork encoder architecture
