Overview
STGNN exposes many hyperparameters that significantly impact model performance. The codebase does not currently include built-in Optuna integration, so this guide documents the key hyperparameters, their recommended search spaces, and tuning strategies based on the implementation.
Key Hyperparameters
Model Architecture
GNN Configuration
Defined in main.py:43-46 and dfc_main.py:46-50.
Recommended search space:
| Parameter | Type | Range | Default | Impact |
|---|---|---|---|---|
| gnn_hidden_dim | int | [128, 256, 512] | 256 | High - affects capacity |
| gnn_num_layers | int | [2, 3, 4, 5] | 2 | Medium - deeper can overfit |
| layer_type | categorical | ["GCN", "GAT", "GraphSAGE"] | GraphSAGE | High - architecture choice |
| gnn_activation | categorical | ["relu", "elu", "leaky_relu", "gelu"] | elu | Low-Medium |
| topk_ratio | float | [0.2, 0.3, 0.4, 0.5] | 0.3 | Medium - sparsity level |
Temporal Model Configuration
Defined in main.py:34-37 and dfc_main.py:34-38.
Recommended search space:
| Parameter | Type | Range | Default | Impact |
|---|---|---|---|---|
| lstm_hidden_dim | int | [32, 64, 128, 256] | 64 | High - temporal capacity |
| lstm_num_layers | int | [1, 2] | 1 | Medium - more layers may overfit |
| lstm_bidirectional | boolean | [True, False] | True | High - captures future context |
| model_type | categorical | ["LSTM", "GRU", "RNN"] | LSTM | High - architecture choice |
Training Configuration
Learning Rate and Optimization
Defined in main.py:26-27 and dfc_main.py:26-27.
Recommended search space:
| Parameter | Type | Range | Default | Impact |
|---|---|---|---|---|
| lr | float | [1e-4, 5e-4, 1e-3, 5e-3] | 1e-3 | High - convergence speed |
| batch_size | int | [8, 16, 32] | 16 | Medium - stability vs. speed |
| n_epochs | int | 100-200 | 100 | Low - use early stopping |
Loss Function Parameters
Defined in main.py:28-31 and dfc_main.py:28-34.
Recommended search space:
| Parameter | Type | Range | Default | Impact |
|---|---|---|---|---|
| focal_alpha | float | [0.70, 0.75, 0.80, 0.85, 0.90] | 0.90 | High - minority class weight |
| focal_gamma | float | [1.0, 2.0, 3.0, 4.0] | 3.0 | Medium - focusing strength |
| label_smoothing | float | [0.0, 0.05, 0.1] | 0.05 | Low - regularization |
| minority_focus_epochs | int | [10, 20, 30] | 20 | Medium - forcing duration |
Temporal Features
Defined in main.py:47-50 and dfc_main.py:50-53.
Recommended search space:
| Parameter | Type | Range | Default | Impact |
|---|---|---|---|---|
| use_time_features | boolean | [True, False] | False | High - temporal awareness |
| time_normalization | categorical | ["log", "minmax", "buckets"] | log | Medium - time encoding |
| exclude_target_visit | boolean | [True, False] | False | High - prevents leakage |
| single_visit_horizon | int | [3, 6, 12] | 6 | Low - single-visit handling |
Pretraining
Defined in dfc_main.py:40-43.
Recommended search space:
| Parameter | Type | Range | Default | Impact |
|---|---|---|---|---|
| freeze_encoder | boolean | [True, False] | False | High - transfer learning |
| pretrain_epochs | int | [30, 50, 100] | 50 | Low - diminishing returns |
Manual Tuning Strategy
Stage 1: Architecture Selection (Priority: High)
- GNN Layer Type
  - GraphSAGE typically works best for brain graphs
  - GAT adds attention but increases complexity
- Temporal Model Type
  - LSTM: best for long sequences
  - GRU: faster, with similar performance
  - RNN: simplest, may underfit
- Bidirectional LSTM
  - Usually performs better, but doubles the parameter count
Stage 2: Capacity Tuning (Priority: High)
- Hidden Dimensions
  - Start from the defaults (gnn_hidden_dim=256, lstm_hidden_dim=64) and scale with dataset size
- Network Depth
  - Start with 2 layers (default)
  - Increase if underfitting, decrease if overfitting
Stage 3: Loss Function Tuning (Priority: High for Imbalanced Data)
- Higher alpha (e.g., 0.90): more weight on the minority class
- Higher gamma (e.g., 3.0): more focus on hard examples
- Adjust both based on the class distribution
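The roles of alpha and gamma can be seen in a minimal NumPy sketch of binary focal loss. This is an illustration only, not the repo's actual loss implementation:

```python
import numpy as np

def binary_focal_loss(p, y, alpha=0.90, gamma=3.0):
    """Mean focal loss for predicted probabilities p and binary labels y.

    alpha weights the positive (minority) class; gamma down-weights
    well-classified examples so training focuses on hard ones.
    """
    p_t = np.where(y == 1, p, 1 - p)              # probability of the true class
    alpha_t = np.where(y == 1, alpha, 1 - alpha)  # class-dependent weight
    return float(np.mean(-alpha_t * (1 - p_t) ** gamma * np.log(p_t)))
```

With gamma=3.0, an easy positive (p_t=0.9) is down-weighted by (0.1)^3 instead of (0.1)^1, so it contributes about 100x less to the loss than with gamma=1.0.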
Stage 4: Learning Rate and Batch Size (Priority: Medium)
- Larger batch → faster but less noise (may converge to poor local minimum)
- Smaller learning rate → more stable but slower
- Use learning rate scheduler (already implemented)
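One common scheduler pattern, sketched here with PyTorch's ReduceLROnPlateau and a stand-in model (the repo's own scheduler setup is the reference):

```python
import torch

model = torch.nn.Linear(10, 1)  # stand-in model for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Halve the LR when validation AUC stops improving; mode="max" because
# higher AUC is better.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=5
)

for epoch in range(10):
    val_auc = 0.5  # stand-in for this epoch's validation AUC
    scheduler.step(val_auc)
```

Because the metric never improves here, the LR is halved once after the patience window is exceeded.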
Stage 5: Regularization (Priority: Medium)
- TopK Pooling Ratio
  - Lower ratio → more aggressive pruning → stronger regularization
- Dropout (fixed in code)
  - Could be exposed as a hyperparameter for tuning
Stage 6: Temporal Features (Priority: High if Using Time)
- Enable use_time_features and compare time_normalization options (log, minmax, buckets)
- Set exclude_target_visit=True to prevent label leakage
Implementing Optuna Integration
While Optuna integration is not currently in the codebase, it is straightforward to add.
Installation
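Assuming a standard Python environment, Optuna installs from PyPI:

```shell
pip install optuna
```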
Example Optuna Objective Function
Pruning for Early Stopping
Visualization
Grid Search Alternative
For systematic exploration without Optuna, enumerate or sample the grids above directly.
Random Search Strategy
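A stdlib-only random search sketch over a subset of the grids above. `train_and_evaluate` is again a hypothetical stand-in for a real training run returning validation AUC:

```python
import random

# Grids taken from the tables above (subset for brevity).
SEARCH_SPACE = {
    "lr": [1e-4, 5e-4, 1e-3, 5e-3],
    "batch_size": [8, 16, 32],
    "gnn_hidden_dim": [128, 256, 512],
    "layer_type": ["GCN", "GAT", "GraphSAGE"],
    "focal_alpha": [0.70, 0.75, 0.80, 0.85, 0.90],
}

def train_and_evaluate(config, rng):
    # Hypothetical stand-in: replace with a real training run that
    # returns validation AUC for this configuration.
    return rng.random()

def random_search(n_trials=20, seed=42):
    rng = random.Random(seed)
    best_score, best_config = float("-inf"), None
    for _ in range(n_trials):
        config = {name: rng.choice(grid) for name, grid in SEARCH_SPACE.items()}
        score = train_and_evaluate(config, rng)
        if score > best_score:
            best_score, best_config = score, config
    return best_config, best_score

best_config, best_score = random_search()
```

Random search often outperforms grid search under a fixed trial budget because it does not waste trials on unimportant dimensions.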
Cross-Validation Considerations
The codebase uses 5-fold cross-validation (main.py:38).
For hyperparameter tuning:
- Use inner CV: 5 outer folds for evaluation, 4 inner folds for tuning
- Or: Fix hyperparameters on fold 1, then evaluate on all 5 folds
- Report mean ± std across folds for final model
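The inner-CV option can be sketched as a nested loop. `tune_inner` and `evaluate` are hypothetical stand-ins for an inner tuning pass and a training run:

```python
import random
import statistics

def kfold_indices(n, k, seed=0):
    """Split range(n) into k disjoint folds after shuffling."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]

def tune_inner(train_idx, inner_k=4):
    # Hypothetical stand-in: run inner CV over train_idx to pick the best
    # hyperparameters (e.g., via the random search above).
    return {"lr": 1e-3, "batch_size": 16}

def evaluate(params, train_idx, test_idx):
    # Hypothetical stand-in: train on train_idx with params, return test AUC.
    return 0.5

def nested_cv(n_samples, outer_k=5):
    outer = kfold_indices(n_samples, outer_k)
    scores = []
    for i, test_idx in enumerate(outer):
        # Tuning only ever sees the outer-train portion, never test_idx.
        train_idx = [j for f, fold in enumerate(outer) if f != i for j in fold]
        params = tune_inner(train_idx)
        scores.append(evaluate(params, train_idx, test_idx))
    return statistics.mean(scores), statistics.stdev(scores)

mean_auc, std_auc = nested_cv(100)
```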
Tracking Experiments
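Where wandb or MLflow is unavailable, the essentials, per-run configuration and metrics, can be captured with a tiny stdlib JSONL logger. This is a hypothetical helper with illustrative values:

```python
import json
import time
from pathlib import Path

LOG_PATH = Path("experiments.jsonl")

def log_run(config, metrics, log_path=LOG_PATH):
    """Append one experiment record (config + metrics) as a JSON line."""
    record = {"timestamp": time.time(), "config": config, "metrics": metrics}
    with log_path.open("a") as f:
        f.write(json.dumps(record) + "\n")

# Illustrative values only.
log_run({"lr": 1e-3, "layer_type": "GraphSAGE"}, {"val_auc": 0.81})
```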
Use Weights & Biases (wandb) or MLflow to track configurations, metrics, and artifacts across runs.
Best Practices
- Start with architecture (layer_type, model_type) - highest impact
- Tune capacity next (hidden_dims, num_layers) - prevents under/overfitting
- Then optimize training (lr, batch_size, focal loss params)
- Finally regularization (topk_ratio, dropout)
- Use validation AUC as primary metric (handles class imbalance)
- Report test metrics only for final model (avoid overfitting to test set)
- Track all experiments - even failed ones provide information
- Set random seeds for reproducibility (already implemented)
- Use early stopping (patience=20 in code)
- Consider computational budget - each full run takes hours
Computational Estimates
Single fold training:
- Static FC: ~30-60 minutes on GPU
- Dynamic FC: ~60-120 minutes on GPU
Full 5-fold cross-validation:
- Static FC: ~2.5-5 hours
- Dynamic FC: ~5-10 hours
Hyperparameter search:
- 50 trials × 5 folds ≈ 10-25 days for DFC
- Use pruning or parallel workers to cut this down
Implementation Files
- main.py:24-51: Command-line arguments for static FC
- dfc_main.py:25-56: Command-line arguments for dynamic FC
- supervised_pretrain.py:157-177: Pretraining hyperparameters
- main.py:294-309: Optimizer and loss configuration
- dfc_main.py:384-389: DFC optimizer and loss configuration