
Overview

STGNN supports both static and dynamic functional connectivity (FC) graphs. Dynamic FC captures time-varying connectivity patterns within a single scan, providing richer temporal information than static FC, which averages connectivity across the entire scan.

Static vs Dynamic FC

Static FC (Conventional Approach)

  • Computation: Single correlation matrix averaged over entire fMRI scan
  • Result: One graph per subject visit
  • Edge weights: Pearson correlation between ROI time series
  • File format: fc_matrix variable in .npz files
  • Training script: main.py
from FC_ADNIDataset import FC_ADNIDataset

dataset = FC_ADNIDataset(
    root="/path/to/data",
    var_name="fc_matrix"  # Static FC variable
)
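For reference, a static FC matrix of this kind can be computed with NumPy. The time-series length and ROI count below are illustrative; the repository's own preprocessing may differ:

```python
import numpy as np

# Illustrative dimensions: 200 timepoints x 90 ROIs (actual values depend
# on the atlas and scan length used in preprocessing)
rng = np.random.default_rng(0)
roi_ts = rng.standard_normal((200, 90))

# Static FC: one Pearson correlation matrix over the whole scan
fc_matrix = np.corrcoef(roi_ts.T)  # (num_rois, num_rois)
```

The resulting symmetric matrix is what `fc_matrix` in the `.npz` files holds: one graph's edge weights per visit.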

Dynamic FC (Advanced Approach)

  • Computation: Sliding window approach creating multiple connectivity snapshots
  • Result: Sequence of graphs per subject visit (time-varying connectivity)
  • Edge weights: Time-windowed correlations capturing connectivity dynamics
  • File format: dynamic_fc variable in .npz files
  • Training script: dfc_main.py
from DFC_ADNIDataset import DFC_ADNIDataset

dataset = DFC_ADNIDataset(
    root="/path/to/data",
    var_name="dynamic_fc"  # Dynamic FC variable
)
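A minimal sliding-window sketch of how such connectivity snapshots can be produced. Window length, stride, and dimensions here are assumptions for illustration, not the repository's actual preprocessing settings:

```python
import numpy as np

def sliding_window_fc(roi_ts, window=50, stride=10):
    """One correlation matrix per sliding window.

    roi_ts: (timepoints, num_rois) array of ROI time series.
    Returns an array of shape (num_windows, num_rois, num_rois).
    """
    mats = [
        np.corrcoef(roi_ts[start:start + window].T)
        for start in range(0, roi_ts.shape[0] - window + 1, stride)
    ]
    return np.stack(mats)

rng = np.random.default_rng(0)
roi_ts = rng.standard_normal((200, 90))
dynamic_fc = sliding_window_fc(roi_ts)  # (16, 90, 90) with these settings
```

Each leading slice of `dynamic_fc` is one graph, which is why DFC yields a sequence of graphs per visit rather than a single matrix.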

Comparison

| Aspect | Static FC | Dynamic FC |
| --- | --- | --- |
| Graphs per visit | 1 | Multiple (sliding windows) |
| Temporal resolution | Averaged | Window-level |
| Information captured | Average connectivity | Time-varying patterns |
| Data size | Smaller | Larger (multiple graphs/visit) |
| Computational cost | Lower | Higher |
| Model complexity | Simpler | More complex (needs temporal aggregation) |
| Biological realism | Less | More (captures dynamics) |

Dynamic FC Model Architecture

The DynamicGraphNeuralNetwork in dfc_model.py is specifically designed to handle DFC data.

Model Configuration

from dfc_model import DynamicGraphNeuralNetwork

encoder = DynamicGraphNeuralNetwork(
    input_dim=dataset.data.x.size(-1),  # Dynamic input from dataset
    hidden_dim=256,                      # Hidden layer size
    output_dim=256,                      # Embedding dimension
    num_classes=2,                       # Binary classification
    dropout=0.5,
    use_topk_pooling=True,              # Hierarchical pooling
    topk_ratio=0.3,                     # Keep 30% of nodes
    layer_type="GraphSAGE",             # GNN type
    temporal_aggregation="mean",        # How to combine windows
    num_layers=3,                       # GNN depth
    activation='relu'                    # Activation function
)
See dfc_main.py:254-265 for initialization.

Architecture Components

1. Input Projection

self.input_proj = nn.Linear(input_dim, hidden_dim)
Projects raw node features to the hidden dimension.

2. GNN Layers

Supports three GNN architectures:
if layer_type == "GCN":
    conv = GCNConv(in_dim, out_dim)
elif layer_type == "GAT":
    conv = GATConv(in_dim, out_dim)
elif layer_type == "GraphSAGE":
    conv = SAGEConv(in_dim, out_dim)
Each layer is followed by:
  • GraphNorm for stable training
  • Activation function (ReLU/ELU/LeakyReLU/GELU)
  • Dropout for regularization
See dfc_model.py:49-63.

3. TopK Pooling (Optional)

self.topk_pools = nn.ModuleList([
    TopKPooling(dim, ratio=safe_ratio)
    for _ in range(num_layers)
])
Hierarchically selects the top fraction of nodes (30% with topk_ratio=0.3) based on learned importance scores:
  • Reduces computational cost
  • Focuses on most relevant brain regions
  • Applied after each GNN layer
Implementation: dfc_model.py:68-75

4. Global Pooling

x_mean = global_mean_pool(x, batch)
x_max = global_max_pool(x, batch)
graph_repr = torch.cat([x_mean, x_max], dim=1)  # (batch, 2*output_dim)
Combines mean and max pooling for robust graph-level representations. See dfc_model.py:117-119.

Temporal Aggregation Strategies

DFC creates multiple graphs per visit. The model aggregates them using one of three strategies:

1. Mean Aggregation (Default)

temporal_aggregation = "mean"
out = time_outputs.mean(dim=1)
  • Average embeddings across all DFC windows
  • Most stable
  • Good for capturing overall connectivity patterns

2. Max Aggregation

temporal_aggregation = "max"
out, _ = time_outputs.max(dim=1)
  • Take maximum activation across windows
  • Emphasizes strongest connectivity patterns
  • More sensitive to outliers

3. GRU Aggregation

temporal_aggregation = "gru"
self.gru = nn.GRU(
    input_size=2*output_dim,
    hidden_size=2*output_dim,
    batch_first=True
)
_, h = self.gru(time_outputs)
out = h.squeeze(0)
  • Sequential processing of DFC windows
  • Captures temporal ordering within visit
  • Most expressive but also most complex
Implementation: dfc_model.py:137-159

Configure via command line:
python dfc_main.py --temporal_aggregation mean  # or max, gru
See dfc_main.py:54-55 for argument definition.
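The two order-invariant strategies simply reduce a (batch, num_windows, features) tensor along the window axis; a NumPy sketch with illustrative shapes (GRU aggregation additionally depends on window order and needs a recurrent layer, so it is omitted here):

```python
import numpy as np

# Per-window graph embeddings for one batch:
# (batch, num_windows, 2 * output_dim) -- shapes here are illustrative
rng = np.random.default_rng(0)
time_outputs = rng.standard_normal((4, 8, 512))

mean_out = time_outputs.mean(axis=1)  # (4, 512): average over windows
max_out = time_outputs.max(axis=1)    # (4, 512): strongest response per dim
```

Both collapse the window axis to one embedding per visit; only the GRU variant can distinguish the order in which windows occur.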

Training with Dynamic FC

Basic Training

python dfc_main.py \
  --n_epochs 100 \
  --batch_size 16 \
  --lr 0.001 \
  --temporal_aggregation mean \
  --layer_type GraphSAGE \
  --gnn_num_layers 2

With Advanced Features

python dfc_main.py \
  --use_time_features \
  --time_normalization log \
  --exclude_target_visit \
  --pretrain_encoder \
  --pretrain_epochs 50 \
  --use_topk_pooling \
  --topk_ratio 0.3 \
  --temporal_aggregation mean

Key Arguments

  • --temporal_aggregation: How to combine DFC windows - mean, max, or gru (default: mean)
  • --layer_type: GNN architecture - GCN, GAT, or GraphSAGE (default: GraphSAGE)
  • --gnn_hidden_dim: Hidden dimension size (default: 256)
  • --gnn_num_layers: Number of GNN layers, 2-5 (default: 2)
  • --gnn_activation: Activation function - relu, leaky_relu, elu, gelu (default: elu)
  • --use_topk_pooling: Enable hierarchical pooling (default: True)
  • --topk_ratio: Fraction of nodes to keep (default: 0.3)
See dfc_main.py:25-55 for all arguments.

DFC Dataset Structure

The DFC_ADNIDataset expects .npz files with:
# Each .npz file contains:
npz_data = np.load(file_path)

dynamic_fc = npz_data['dynamic_fc']  # (num_windows, num_rois, num_rois)
subject_id = npz_data['subject_id']  # String identifier
visit_code = npz_data.get('visit_code', 'bl')  # Visit code
label = npz_data['label']  # 0=stable, 1=converter
Multiple graphs per visit are handled automatically:
  • Each window creates a separate graph
  • Subject ID tracks all graphs from same visit
  • Temporal aggregation combines windows
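A hypothetical example of writing one visit in this layout (dimensions and IDs are made up, and an in-memory buffer stands in for a real file path):

```python
import io
import numpy as np

# In-memory stand-in for a real .npz path on disk
buf = io.BytesIO()
np.savez(
    buf,
    dynamic_fc=np.zeros((16, 90, 90), dtype=np.float32),  # windows x ROIs x ROIs
    subject_id="sub-0001",
    visit_code="bl",
    label=0,  # 0 = stable
)
buf.seek(0)

npz = np.load(buf)
print(npz["dynamic_fc"].shape)  # prints (16, 90, 90)
```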

Subject Graph Mapping

if not hasattr(dataset, 'subject_graph_dict'):
    mapping = {}
    for data in dataset:
        sid = getattr(data, 'subj_id', None)
        if sid is None:
            continue  # skip graphs without a subject ID
        # Extract base subject ID (strip run suffix and BIDS prefix)
        if '_run' in sid:
            base_sid = sid.split('_run')[0].replace('sub-', '')
        else:
            base_sid = sid.replace('sub-', '')
        mapping.setdefault(base_sid, []).append(data)
    dataset.subject_graph_dict = mapping
This groups all visits (and DFC windows) by subject for temporal modeling. See dfc_main.py:195-209.

Forward Pass Methods

The DynamicGraphNeuralNetwork provides two forward methods:

1. Single Graph Forward

def forward(self, x, edge_index, batch, time_features=None):
    # Process a single graph (one DFC window)
    x = self.input_proj(x)
    
    for i in range(self.num_layers):
        x = self.convs[i](x, edge_index)
        x = self.gns[i](x, batch)
        x = self.activation_fn(x)
        x = self.dropout(x)
        
        if self.use_topk_pooling:
            x, edge_index, _, batch, _, _ = self.topk_pools[i](
                x, edge_index, batch=batch
            )
    
    x_mean = global_mean_pool(x, batch)
    x_max = global_max_pool(x, batch)
    return torch.cat([x_mean, x_max], dim=1)
Used by TemporalDataLoader for feature extraction. See dfc_model.py:89-121.

2. Sequence Forward

def forward_sequence(self, x_seq, edge_index_seq, batch_seq):
    # Process multiple DFC windows and classify
    time_outputs = []
    
    for t in range(len(x_seq)):
        graph_repr = self.forward(
            x_seq[t], edge_index_seq[t], batch_seq[t]
        )
        time_outputs.append(graph_repr)
    
    time_outputs = torch.stack(time_outputs, dim=1)
    
    # Apply temporal aggregation
    if self.temporal_aggregation == "mean":
        out = time_outputs.mean(dim=1)
    elif self.temporal_aggregation == "max":
        out, _ = time_outputs.max(dim=1)
    elif self.temporal_aggregation == "gru":
        _, h = self.gru(time_outputs)
        out = h.squeeze(0)
    
    # Classification
    logits = self.classifier(out)
    return logits
Used for end-to-end DFC window aggregation and classification. See dfc_model.py:123-163.

Data Preprocessing

Handling Infinite Values

# Clean up any infinite values in node features
dataset.data.x[torch.isinf(dataset.data.x)] = 0
DFC windows can yield infinite or undefined correlation values (e.g., when a signal is nearly constant within a short window); these are replaced with 0. See dfc_main.py:192.

Visit Trimming

Limit maximum visits per subject to prevent memory issues:
for subject_id in dataset.subject_graph_dict:
    visits = dataset.subject_graph_dict[subject_id]
    if len(visits) > max_visits:
        # Keep most recent visits
        dataset.subject_graph_dict[subject_id] = visits[-max_visits:]
This is important because DFC creates many graphs per visit.

Training Configuration

Optimizer

optimizer = torch.optim.Adam([
    {'params': fold_encoder.parameters(), 'lr': opt.lr * 0.5},  # Lower LR for encoder
    {'params': classifier.parameters(), 'lr': opt.lr}           # Higher LR for classifier
], weight_decay=1e-4)
Differential learning rates for stable fine-tuning. See dfc_main.py:384-387.

Scheduler

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,      # Restart every 10 epochs
    T_mult=2,    # Double period after each restart
    eta_min=1e-6 # Minimum learning rate
)
Cosine annealing with warm restarts for better convergence. See dfc_main.py:388.

Loss Function

from FocalLoss import FocalLoss

criterion = FocalLoss(
    alpha=0.75,          # Weight for minority class
    gamma=2.0,           # Focusing parameter
    label_smoothing=0.05 # Soft labels
)
Focal loss handles class imbalance (many stable, few converters). See dfc_main.py:389 and arguments at dfc_main.py:29-31.
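For intuition, a minimal binary focal-loss sketch on predicted probabilities. This illustrates the idea only; it is not the repository's FocalLoss.py, which also supports label smoothing:

```python
import numpy as np

def binary_focal_loss(probs, labels, alpha=0.75, gamma=2.0):
    """Focal loss on predicted positive-class probabilities.

    (1 - p_t) ** gamma down-weights easy, confident examples;
    alpha re-weights the positive (minority) class.
    """
    p_t = np.where(labels == 1, probs, 1 - probs)
    alpha_t = np.where(labels == 1, alpha, 1 - alpha)
    return float(np.mean(-alpha_t * (1 - p_t) ** gamma * np.log(p_t + 1e-8)))

# A confident correct prediction contributes far less than an uncertain one
easy = binary_focal_loss(np.array([0.95]), np.array([1]))
hard = binary_focal_loss(np.array([0.55]), np.array([1]))
```

This focusing behavior is what keeps the many easy "stable" examples from drowning out the rare converters during training.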

Evaluation

Evaluation follows the same protocol as static FC:
  • Cross-validation at subject level
  • Metrics: accuracy, balanced accuracy, F1, AUC
  • Optional horizon-based analysis with time features
  • Conversion-specific accuracy tracking
See dfc_main.py:399-460 for evaluation function.
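Balanced accuracy is worth spelling out, since plain accuracy is misleading under the stable/converter imbalance. A small NumPy sketch (the toy labels are illustrative):

```python
import numpy as np

def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall; unaffected by class frequency."""
    recalls = [
        np.mean(y_pred[y_true == cls] == cls)
        for cls in np.unique(y_true)
    ]
    return float(np.mean(recalls))

# 4 stable (0) and 2 converters (1); one converter is missed
y_true = np.array([0, 0, 0, 0, 1, 1])
y_pred = np.array([0, 0, 0, 0, 1, 0])
# plain accuracy = 5/6 ~ 0.83, balanced accuracy = (1.0 + 0.5) / 2 = 0.75
```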

When to Use DFC vs Static FC

Use Static FC when:

  • Limited computational resources
  • Simpler, more interpretable models needed
  • Baseline comparison required
  • Simpler data preprocessing is preferred

Use Dynamic FC when:

  • Maximum predictive performance is critical
  • Capturing brain dynamics is important
  • Sufficient computational resources available
  • Research question involves connectivity changes

Best Practices

  1. Start with static FC to establish baseline performance
  2. Use mean aggregation for DFC - it’s most stable
  3. Enable TopK pooling to reduce computational cost
  4. Use GraphSAGE as the GNN layer - works well for brain graphs
  5. Set gnn_num_layers=2 - deeper models may overfit
  6. Monitor GPU memory - DFC uses more memory than static FC
  7. Apply same temporal modeling (LSTM/GRU) as static FC for fair comparison

Implementation Files

  • dfc_main.py: Main training script for dynamic FC
  • dfc_model.py: DynamicGraphNeuralNetwork architecture
  • DFC_ADNIDataset.py: Dataset loader for DFC data
  • TemporalDataLoader.py: Batch creation (works with both static and DFC)
  • model.py: GraphNeuralNetwork for static FC (for comparison)
