Overview

STGNN has numerous hyperparameters that significantly impact model performance. While the codebase does not currently include built-in Optuna integration, this guide documents the key hyperparameters, their search spaces, and recommended tuning strategies based on the implementation.

Key Hyperparameters

Model Architecture

GNN Configuration

# From main.py and dfc_main.py
parser.add_argument('--gnn_hidden_dim', type=int, default=256)
parser.add_argument('--gnn_num_layers', type=int, default=2)
parser.add_argument('--layer_type', type=str, default="GraphSAGE")
parser.add_argument('--gnn_activation', type=str, default='elu')
parser.add_argument('--use_topk_pooling', action='store_true', default=True)
parser.add_argument('--topk_ratio', type=float, default=0.3)
See main.py:43-46 and dfc_main.py:46-50. Recommended search space:
Parameter | Type | Range | Default | Impact
gnn_hidden_dim | int | [128, 256, 512] | 256 | High - affects capacity
gnn_num_layers | int | [2, 3, 4, 5] | 2 | Medium - deeper can overfit
layer_type | categorical | ["GCN", "GAT", "GraphSAGE"] | GraphSAGE | High - architecture choice
gnn_activation | categorical | ["relu", "elu", "leaky_relu", "gelu"] | elu | Low-Medium
topk_ratio | float | [0.2, 0.3, 0.4, 0.5] | 0.3 | Medium - sparsity level
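
For reference only (this is not the repository's encoder), the sketch below shows how these arguments could map onto a PyTorch Geometric layer stack, with the convolution class selected by layer_type and a TopKPooling stage controlled by topk_ratio:
# Illustrative sketch, not the repository's model: maps the arguments above
# onto a PyTorch Geometric encoder.
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, TopKPooling

LAYER_TYPES = {"GCN": GCNConv, "GAT": GATConv, "GraphSAGE": SAGEConv}

class ToyGNNEncoder(nn.Module):
    def __init__(self, in_dim, hidden_dim=256, num_layers=2,
                 layer_type="GraphSAGE", topk_ratio=0.3, activation="elu"):
        super().__init__()
        conv_cls = LAYER_TYPES[layer_type]
        dims = [in_dim] + [hidden_dim] * num_layers
        self.convs = nn.ModuleList(
            conv_cls(dims[i], dims[i + 1]) for i in range(num_layers))
        self.pool = TopKPooling(hidden_dim, ratio=topk_ratio)  # keeps the top fraction of nodes
        self.act = getattr(F, activation)

    def forward(self, x, edge_index, batch):
        for conv in self.convs:
            x = self.act(conv(x, edge_index))
        # TopKPooling returns (x, edge_index, edge_attr, batch, perm, score)
        x, edge_index, _, batch, _, _ = self.pool(x, edge_index, batch=batch)
        return x, batch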

Temporal Model Configuration

# LSTM/GRU/RNN hyperparameters
parser.add_argument('--lstm_hidden_dim', type=int, default=64)
parser.add_argument('--lstm_num_layers', type=int, default=1)
parser.add_argument('--lstm_bidirectional', type=bool, default=True)
parser.add_argument('--model_type', type=str, default='LSTM')
See main.py:34-37 and dfc_main.py:34-38. Recommended search space:
Parameter | Type | Range | Default | Impact
lstm_hidden_dim | int | [32, 64, 128, 256] | 64 | High - temporal capacity
lstm_num_layers | int | [1, 2] | 1 | Medium - more layers → overfitting
lstm_bidirectional | boolean | [True, False] | True | High - captures future context
model_type | categorical | ["LSTM", "GRU", "RNN"] | LSTM | High - architecture choice
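
To make the bidirectional trade-off concrete, here is a minimal sketch (names are illustrative, not from the codebase) of how model_type, lstm_hidden_dim, lstm_num_layers, and lstm_bidirectional might be wired into a sequence head; a bidirectional model doubles the classifier's input width:
# Minimal sketch, not the repository's classifier: choosing the recurrent cell
# and sizing the classifier input for bidirectional models.
import torch.nn as nn

RNN_TYPES = {"LSTM": nn.LSTM, "GRU": nn.GRU, "RNN": nn.RNN}

def build_temporal_head(input_dim, model_type="LSTM", hidden_dim=64,
                        num_layers=1, bidirectional=True, num_classes=2):
    rnn = RNN_TYPES[model_type](
        input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers,
        batch_first=True, bidirectional=bidirectional)
    # Bidirectional RNNs concatenate forward and backward states,
    # so the classifier sees twice the hidden dimension.
    out_dim = hidden_dim * (2 if bidirectional else 1)
    return rnn, nn.Linear(out_dim, num_classes)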

Training Configuration

Learning Rate and Optimization

parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--n_epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=16)
See main.py:26-27 and dfc_main.py:26-27. Recommended search space:
Parameter | Type | Range | Default | Impact
lr | float | [1e-4, 5e-4, 1e-3, 5e-3] | 1e-3 | High - convergence speed
batch_size | int | [8, 16, 32] | 16 | Medium - stability vs speed
n_epochs | int | 100-200 | 100 | Low - use early stopping
Optimizer configuration (fixed in code):
# From main.py:294-297
optimizer = torch.optim.Adam([
    {'params': encoder_params, 'lr': opt.lr},
    {'params': classifier_params, 'lr': opt.lr}
], betas=(0.9, 0.999), weight_decay=1e-4)

Loss Function Parameters

# Focal loss for handling class imbalance
parser.add_argument('--focal_alpha', type=float, default=0.90)
parser.add_argument('--focal_gamma', type=float, default=3.0)
parser.add_argument('--label_smoothing', type=float, default=0.05)
parser.add_argument('--minority_focus_epochs', type=int, default=20)
See main.py:28-31 and dfc_main.py:28-34. Recommended search space:
Parameter | Type | Range | Default | Impact
focal_alpha | float | [0.70, 0.75, 0.80, 0.85, 0.90] | 0.90 | High - minority class weight
focal_gamma | float | [1.0, 2.0, 3.0, 4.0] | 3.0 | Medium - focusing strength
label_smoothing | float | [0.0, 0.05, 0.1] | 0.05 | Low - regularization
minority_focus_epochs | int | [10, 20, 30] | 20 | Medium - duration of forced minority focus
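
For intuition on how focal_alpha and focal_gamma interact, a minimal binary focal loss looks like the sketch below (assuming logits and 0/1 labels; the repository's loss also applies label smoothing and a minority-focus schedule, which are not reproduced here):
# Minimal binary focal loss sketch; illustrative only.
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.90, gamma=3.0):
    targets = targets.float()
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = torch.exp(-bce)                                    # probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # weight for the minority (positive) class
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()       # gamma down-weights easy examples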

Temporal Features

parser.add_argument('--use_time_features', action='store_true')
parser.add_argument('--time_normalization', type=str, default='log')
parser.add_argument('--exclude_target_visit', action='store_true')
parser.add_argument('--single_visit_horizon', type=int, default=6)
See main.py:47-50 and dfc_main.py:50-53. Recommended search space:
Parameter | Type | Range | Default | Impact
use_time_features | boolean | [True, False] | False | High - temporal awareness
time_normalization | categorical | ["log", "minmax", "buckets"] | log | Medium - time encoding
exclude_target_visit | boolean | [True, False] | False | High - prevents leakage
single_visit_horizon | int | [3, 6, 12] | 6 | Low - single-visit handling
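
The three time_normalization options correspond to different encodings of the inter-visit interval. The repository's exact implementation is not shown here; the sketch below illustrates what each option typically means (the bucket edges and max_months scale are assumptions, not values from the codebase):
# Illustrative encodings for an inter-visit gap in months.
import numpy as np

def normalize_time(delta_months, mode="log", max_months=120.0):
    delta = np.asarray(delta_months, dtype=float)
    if mode == "log":
        return np.log1p(delta)                        # compresses long gaps
    if mode == "minmax":
        return np.clip(delta / max_months, 0.0, 1.0)  # scales to [0, 1]
    if mode == "buckets":
        return np.digitize(delta, bins=[6, 12, 24])   # coarse ordinal buckets
    raise ValueError(f"unknown mode: {mode}")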

Pretraining

# For DFC model (dfc_main.py)
parser.add_argument('--pretrain_encoder', action='store_true')
parser.add_argument('--pretrain_epochs', type=int, default=50)
parser.add_argument('--freeze_encoder', action='store_true')
parser.add_argument('--use_pretrained', action='store_true')
See dfc_main.py:40-43. Recommended search space:
Parameter | Type | Range | Default | Impact
freeze_encoder | boolean | [True, False] | False | High - transfer learning
pretrain_epochs | int | [30, 50, 100] | 50 | Low - diminishing returns
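
A minimal sketch of how a pretrained encoder could be loaded and frozen; the checkpoint path and module names are placeholders, and the repository's actual loading logic may differ:
# Illustrative only: loading a pretrained encoder and optionally freezing it.
import torch

def load_encoder(encoder, checkpoint_path, freeze=False):
    state = torch.load(checkpoint_path, map_location="cpu")
    encoder.load_state_dict(state)
    if freeze:
        for param in encoder.parameters():
            param.requires_grad = False  # excluded from optimizer updates
    return encoder
When the encoder is frozen, remember to filter its parameters out of the optimizer (or rely on requires_grad=False).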

Manual Tuning Strategy

Stage 1: Architecture Selection (Priority: High)

  1. GNN Layer Type
    for layer in GCN GAT GraphSAGE; do
        python main.py --layer_type $layer --gnn_num_layers 2
    done
    
    • GraphSAGE typically works best for brain graphs
    • GAT adds attention but increases complexity
  2. Temporal Model Type
    for model in LSTM GRU RNN; do
        python main.py --model_type $model
    done
    
    • LSTM: Best for long sequences
    • GRU: Faster, similar performance
    • RNN: Simplest, may underfit
  3. Bidirectional LSTM
    python main.py --lstm_bidirectional True
    python main.py --lstm_bidirectional False
    
    • Bidirectional usually better but doubles parameters
    • Caveat: --lstm_bidirectional is declared with type=bool, and argparse converts any non-empty string (including "False") to True, so the second command above does not actually disable it; use a converter like the one sketched after this list
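
A common workaround (not present in the codebase) is a small string-to-bool converter for argparse:
# Hypothetical helper, not in the repository: lets "--lstm_bidirectional False"
# actually parse as False instead of bool("False") == True.
import argparse

def str2bool(value):
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

# replaces the existing type=bool declaration
parser.add_argument('--lstm_bidirectional', type=str2bool, default=True)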

Stage 2: Capacity Tuning (Priority: High)

  1. Hidden Dimensions
    for gnn_dim in 128 256 512; do
        for lstm_dim in 32 64 128; do
            python main.py --gnn_hidden_dim $gnn_dim --lstm_hidden_dim $lstm_dim
        done
    done
    
  2. Network Depth
    for layers in 2 3 4; do
        python main.py --gnn_num_layers $layers
    done
    
    • Start with 2 layers (default)
    • Increase if underfitting, decrease if overfitting

Stage 3: Loss Function Tuning (Priority: High for Imbalanced Data)

for alpha in 0.70 0.75 0.80 0.85 0.90; do
    for gamma in 1.0 2.0 3.0; do
        python main.py --focal_alpha $alpha --focal_gamma $gamma
    done
done
Interpretation:
  • High alpha (e.g., 0.90): More weight to minority class
  • High gamma (e.g., 3.0): More focus on hard examples
  • Adjust based on class distribution

Stage 4: Learning Rate and Batch Size (Priority: Medium)

for lr in 0.0001 0.0005 0.001 0.005; do
    for bs in 8 16 32; do
        python main.py --lr $lr --batch_size $bs
    done
done
Guidelines:
  • Larger batches train faster but add less gradient noise, which can settle into poorer minima
  • Smaller learning rates are more stable but converge more slowly
  • Use the learning rate scheduler (already implemented); a generic example follows this list
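
The repository already implements a scheduler; which one is not documented in this guide. Purely as an illustration, a validation-AUC-driven setup often uses ReduceLROnPlateau (optimizer, opt, evaluate, and val_loader below are placeholders from the training loop):
# Example only; not necessarily the scheduler used in main.py.
import torch

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5)

for epoch in range(opt.n_epochs):
    # ... train one epoch ...
    val_auc = evaluate(val_loader)
    scheduler.step(val_auc)  # halve the LR when validation AUC plateaus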

Stage 5: Regularization (Priority: Medium)

  1. TopK Pooling Ratio
    for ratio in 0.2 0.3 0.4 0.5; do
        python main.py --topk_ratio $ratio
    done
    
    • Lower ratio → more aggressive pruning → stronger regularization
  2. Dropout (Fixed in Code)
    # GNN dropout: 0.2 (supervised_pretrain.py:174)
    # Temporal classifier dropout: 0.45 (main.py:262)
    
    • Could be exposed as a hyperparameter for tuning (a possible change is sketched after this list)
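
For example, the hard-coded rates could be surfaced as arguments. The flags below are hypothetical; they do not exist in the current codebase:
# Hypothetical arguments (not in the codebase) replacing the hard-coded dropout values
parser.add_argument('--gnn_dropout', type=float, default=0.2)
parser.add_argument('--classifier_dropout', type=float, default=0.45)
# ...then pass opt.gnn_dropout / opt.classifier_dropout into the model constructors.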

Stage 6: Temporal Features (Priority: High if Using Time)

# Without time features
python main.py --exclude_target_visit

# With time features + different normalizations
for norm in log minmax buckets; do
    python main.py --use_time_features --time_normalization $norm --exclude_target_visit
done

Implementing Optuna Integration

While not currently in the codebase, here’s how to add Optuna:

Installation

pip install optuna

Example Optuna Objective Function

import optuna
from optuna.trial import Trial

def objective(trial: Trial) -> float:
    # Suggest hyperparameters
    lr = trial.suggest_float('lr', 1e-4, 5e-3, log=True)
    batch_size = trial.suggest_categorical('batch_size', [8, 16, 32])
    gnn_hidden_dim = trial.suggest_categorical('gnn_hidden_dim', [128, 256, 512])
    lstm_hidden_dim = trial.suggest_categorical('lstm_hidden_dim', [32, 64, 128])
    layer_type = trial.suggest_categorical('layer_type', ['GCN', 'GAT', 'GraphSAGE'])
    gnn_num_layers = trial.suggest_int('gnn_num_layers', 2, 4)
    topk_ratio = trial.suggest_float('topk_ratio', 0.2, 0.5)
    focal_alpha = trial.suggest_float('focal_alpha', 0.7, 0.95)
    focal_gamma = trial.suggest_float('focal_gamma', 1.0, 4.0)
    
    # Update opt (the argparse namespace parsed in main.py) with trial suggestions
    opt.lr = lr
    opt.batch_size = batch_size
    opt.gnn_hidden_dim = gnn_hidden_dim
    opt.lstm_hidden_dim = lstm_hidden_dim
    opt.layer_type = layer_type
    opt.gnn_num_layers = gnn_num_layers
    opt.topk_ratio = topk_ratio
    opt.focal_alpha = focal_alpha
    opt.focal_gamma = focal_gamma
    
    # Run training and get validation metric
    # (extract training loop from main.py into a function)
    val_auc = train_and_evaluate(opt)
    
    return val_auc  # Maximize AUC

# Create study
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)

print('Best hyperparameters:', study.best_params)
print('Best AUC:', study.best_value)

Pruning for Early Stopping

import optuna
from optuna.pruners import MedianPruner

# Create study with pruner
study = optuna.create_study(
    direction='maximize',
    pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=10)
)

def objective(trial):
    # ... hyperparameter suggestions ...
    
    for epoch in range(opt.n_epochs):
        # ... training ...
        val_auc = evaluate(val_loader)
        
        # Report and prune
        trial.report(val_auc, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
    
    return val_auc

study.optimize(objective, n_trials=100)

Visualization

import optuna.visualization as vis

# Optimization history
fig = vis.plot_optimization_history(study)
fig.show()

# Parameter importances
fig = vis.plot_param_importances(study)
fig.show()

# Parallel coordinate plot
fig = vis.plot_parallel_coordinate(study)
fig.show()

Grid Search Alternative

For systematic exploration without Optuna:
#!/bin/bash
# grid_search.sh

for layer in GraphSAGE GAT; do
    for gnn_dim in 256 512; do
        for lstm_dim in 64 128; do
            for lr in 0.0005 0.001; do
                for alpha in 0.85 0.90; do
                    python main.py \
                        --layer_type $layer \
                        --gnn_hidden_dim $gnn_dim \
                        --lstm_hidden_dim $lstm_dim \
                        --lr $lr \
                        --focal_alpha $alpha \
                        --save_path ./models/grid_${layer}_${gnn_dim}_${lstm_dim}_${lr}_${alpha}
                done
            done
        done
    done
done

Random Search Strategy

import random
import subprocess

n_trials = 50

for trial in range(n_trials):
    # Sample hyperparameters
    lr = 10 ** random.uniform(-4, -2.3)  # 1e-4 to 5e-3
    batch_size = random.choice([8, 16, 32])
    gnn_hidden_dim = random.choice([128, 256, 512])
    lstm_hidden_dim = random.choice([32, 64, 128])
    layer_type = random.choice(['GCN', 'GAT', 'GraphSAGE'])
    gnn_num_layers = random.randint(2, 4)
    topk_ratio = random.uniform(0.2, 0.5)
    focal_alpha = random.uniform(0.7, 0.95)
    focal_gamma = random.uniform(1.0, 4.0)
    
    # Run training
    cmd = f"""python main.py \
        --lr {lr} \
        --batch_size {batch_size} \
        --gnn_hidden_dim {gnn_hidden_dim} \
        --lstm_hidden_dim {lstm_hidden_dim} \
        --layer_type {layer_type} \
        --gnn_num_layers {gnn_num_layers} \
        --topk_ratio {topk_ratio} \
        --focal_alpha {focal_alpha} \
        --focal_gamma {focal_gamma} \
        --save_path ./models/random_trial_{trial}
    """
    
    subprocess.run(cmd, shell=True)

Cross-Validation Considerations

The codebase uses 5-fold cross-validation:
parser.add_argument('--num_folds', type=int, default=5)
See main.py:38. For hyperparameter tuning:
  • Use nested CV: 5 outer folds for evaluation, 4 inner folds for tuning (see the sketch after this list)
  • Or: Fix hyperparameters on fold 1, then evaluate on all 5 folds
  • Report mean ± std across folds for final model
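
A skeleton of the nested scheme (not in the codebase; y, candidate_configs, and train_and_evaluate are placeholders, the latter being the training loop extracted from main.py into a function):
# Nested cross-validation skeleton; illustrative only.
import numpy as np
from sklearn.model_selection import StratifiedKFold

outer = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
inner = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)

outer_scores = []
for train_idx, test_idx in outer.split(np.zeros(len(y)), y):
    # Inner loop: choose hyperparameters using the outer-training subjects only
    best_cfg, best_auc = None, -np.inf
    for cfg in candidate_configs:
        aucs = [train_and_evaluate(cfg, train_idx[tr], train_idx[va])
                for tr, va in inner.split(np.zeros(len(train_idx)), y[train_idx])]
        if np.mean(aucs) > best_auc:
            best_cfg, best_auc = cfg, np.mean(aucs)
    # Outer loop: evaluate the chosen configuration on the held-out fold
    outer_scores.append(train_and_evaluate(best_cfg, train_idx, test_idx))

print(f"AUC: {np.mean(outer_scores):.3f} ± {np.std(outer_scores):.3f}")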

Tracking Experiments

Use Weights & Biases (wandb) or MLflow:
import wandb

wandb.init(project="stgnn-tuning", config=opt)

for epoch in range(opt.n_epochs):
    # ... training ...
    wandb.log({
        'train_loss': avg_loss,
        'train_acc': train_acc,
        'val_acc': val_results['accuracy'],
        'val_auc': val_results['auc']
    })

wandb.log({'test_auc': test_results['auc']})

Best Practices

  1. Start with architecture (layer_type, model_type) - highest impact
  2. Tune capacity next (hidden_dims, num_layers) - prevents under/overfitting
  3. Then optimize training (lr, batch_size, focal loss params)
  4. Finally regularization (topk_ratio, dropout)
  5. Use validation AUC as primary metric (handles class imbalance)
  6. Report test metrics only for final model (avoid overfitting to test set)
  7. Track all experiments - even failed ones provide information
  8. Set random seeds for reproducibility (already implemented)
  9. Use early stopping (patience=20 in code)
  10. Consider computational budget - each full run takes hours

Computational Estimates

Single fold training:
  • Static FC: ~30-60 minutes on GPU
  • Dynamic FC: ~60-120 minutes on GPU
Full 5-fold CV:
  • Static FC: ~2.5-5 hours
  • Dynamic FC: ~5-10 hours
Hyperparameter search:
  • 50 trials × 5 folds = ~10-25 days for DFC
  • Use pruning or parallel workers (a sketch follows this list)
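
One way to parallelize the search, shown as a sketch: Optuna can run trials concurrently with n_jobs, or several processes can share a storage-backed study (the SQLite path and study name below are arbitrary; objective is the function defined earlier):
# Sketch: parallel Optuna trials sharing a storage-backed study.
import optuna

study = optuna.create_study(
    study_name="stgnn-tuning",
    storage="sqlite:///stgnn_tuning.db",  # shared storage lets multiple processes cooperate
    direction="maximize",
    load_if_exists=True)
study.optimize(objective, n_trials=50, n_jobs=4)  # 4 concurrent trials in this process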

Implementation Files

  • main.py:24-51: Command-line arguments for static FC
  • dfc_main.py:25-56: Command-line arguments for dynamic FC
  • supervised_pretrain.py:157-177: Pretraining hyperparameters
  • main.py:294-309: Optimizer and loss configuration
  • dfc_main.py:384-389: DFC optimizer and loss configuration
