Usage

Run training with main.py, using the parameters described below.

Training Parameters
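A typical invocation might look like the following sketch; every flag name here is an assumption, so check the argparse definitions in main.py for the exact spelling:

```shell
# Hypothetical example; flag names are assumptions, not verified against main.py.
python main.py \
  --epochs 100 \
  --batch_size 8 \
  --lr 1e-4 \
  --save_path ./checkpoints
```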
Basic Training
Epochs
Number of training epochs. Training continues until either this limit is reached or early stopping triggers (patience: 20 epochs).
Location: main.py:25

Batch size
Batch size for temporal sequences (number of subjects per batch). Each subject may have multiple visits, so the actual number of graphs processed per batch is higher.
Location: main.py:26
Note: Memory usage scales with batch size × max visits per subject.

Learning rate
Learning rate for the Adam optimizer. Applied to both encoder and classifier parameters.
Location: main.py:27
Scheduler: automatically reduced by 50% after 10 epochs without validation improvement (minimum: 1e-6).

Save path
Directory path for saving trained models. The best model for each fold is saved as best_model_fold{N}.pth.
Location: main.py:31

Save model
Whether to save the best model checkpoint for each fold.
Location: main.py:32

Loss Function Parameters
Focal Loss
Alpha
Weight for the minority class (converters) in Focal Loss. Range: [0, 1].
- alpha = 0.90: 90% weight to converters (the minority class)
- alpha = 0.50: equal weight to both classes
- Higher values increase focus on the minority class.
Location: main.py:28
Implementation: see FocalLoss.py:9-29

Gamma
Focusing parameter for Focal Loss. Higher values down-weight easy examples more aggressively.
- gamma = 0: equivalent to standard cross-entropy
- gamma = 2: original Focal Loss paper recommendation
- gamma = 3: stronger focus on hard examples (current default)
Location: main.py:29
Formula: Loss = -α(1 - p_t)^γ log(p_t)

Label smoothing
Label smoothing factor applied before computing Focal Loss. Prevents overconfident predictions.
- 0.0: no smoothing (hard labels)
- 0.05: 5% smoothing (current default)
- 0.1: 10% smoothing (stronger regularization)
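Putting alpha, gamma, and label smoothing together, the loss for a single binary example can be sketched in plain Python. This is an illustration only, not the repository's FocalLoss.py implementation; the function name and signature are my own:

```python
import math

def focal_loss(p_converter: float, target: int, alpha: float = 0.90,
               gamma: float = 3.0, smoothing: float = 0.05) -> float:
    """Focal loss for one binary example: -alpha_t * (1 - p_t)^gamma * log(p_t).

    p_converter is the predicted probability of the converter class;
    target is 1 for converter, 0 for stable. Illustrative sketch only.
    """
    # Label smoothing: soften the hard target before computing the loss.
    y = target * (1 - smoothing) + 0.5 * smoothing
    loss = 0.0
    for t, p, a in ((y, p_converter, alpha),
                    (1 - y, 1 - p_converter, 1 - alpha)):
        # Each class contributes -a * (1 - p)^gamma * log(p), weighted by t.
        loss += -t * a * (1 - p) ** gamma * math.log(max(p, 1e-12))
    return loss
```

With gamma = 0, smoothing = 0, and alpha = 0.5 this reduces to (half of) standard cross-entropy, matching the gamma = 0 bullet above.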
Location: main.py:30

Minority Class Forcing
Forcing epochs
Number of initial epochs during which the minority-class forcing loss is applied. This additional loss term encourages the model to predict the converter class during early training.
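The forcing weight decays linearly from 0.1 to 0 over these epochs; a sketch of that schedule (the function name is my own, and the real logic in main.py may differ):

```python
def forcing_weight(epoch: int, forcing_epochs: int, start: float = 0.1) -> float:
    """Linear decay of the minority-class forcing weight from `start` to 0.

    Illustrative sketch of the schedule described for main.py:465-482.
    """
    if forcing_epochs <= 0 or epoch >= forcing_epochs:
        return 0.0  # forcing is disabled after the initial window
    return start * (1 - epoch / forcing_epochs)
```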
Location: main.py:33
Implementation: the forcing weight decays linearly from 0.1 to 0 over these epochs (main.py:465-482).

Temporal Model Parameters
RNN Architecture
RNN type
Type of recurrent neural network used for temporal modeling.
Options:
- LSTM: Long Short-Term Memory (default, best for long sequences)
- GRU: Gated Recurrent Unit (fewer parameters, faster training)
- RNN: vanilla RNN (simplest, may have gradient issues)
Location: main.py:38
Models: TemporalPredictor.py, GRUPredictor.py, RNNPredictor.py

Hidden dimension
Hidden dimension for the LSTM/GRU/RNN layers. Controls the capacity of temporal modeling.
Location: main.py:34
Note: The final classification input is hidden_dim * 2 for a bidirectional LSTM, otherwise hidden_dim.

Number of layers
Number of stacked LSTM/GRU/RNN layers.
Location: main.py:35
Note: Dropout between layers is automatically applied when num_layers > 1 (dropout = 0.45).

Bidirectional
Use a bidirectional LSTM (processes sequences forward and backward).
Location: main.py:36
Note: Only supported for the LSTM model type; GRU and RNN are always unidirectional.

Sequence Processing
Max visits
Maximum number of visits to include per subject. Longer sequences are trimmed to keep only the most recent visits.
Location: main.py:39
Trimming logic: main.py:110-115

Exclude target visit
Exclude the target visit from input sequences to prevent data leakage. When enabled:
- Multi-visit subjects: Use visits 1 to N-1 as input, predict visit N
- Single-visit subjects: Use current visit, predict future state
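The trimming and target-exclusion rules above can be sketched in plain Python. This is an illustration only; the real logic lives in main.py:110-115 and may differ in details:

```python
def build_input_sequence(visits: list, max_visits: int,
                         exclude_target: bool = True) -> tuple:
    """Return (input_visits, target_visit) for one subject.

    Illustrative sketch: `visits` is assumed ordered oldest -> newest.
    """
    if exclude_target and len(visits) > 1:
        # Multi-visit subject: visits 1..N-1 are input, visit N is the target.
        inputs, target = visits[:-1], visits[-1]
    else:
        # Single-visit subject: use the current visit, predict the future state.
        inputs, target = visits, visits[-1]
    # Keep only the most recent max_visits inputs.
    return inputs[-max_visits:], target
```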
Location: main.py:48
Action: store_true (include the flag to enable)

Graph Neural Network Parameters
Architecture
Convolution type
Type of graph convolutional layer.
Options:
- GraphSAGE: inductive learning via sampling and aggregating (default)
- GCN: Graph Convolutional Network (spectral approach)
- GAT: Graph Attention Network (learned edge importance)
Location: main.py:43
Implementation: model.py:56-63

Hidden dimension
Hidden dimension for the GNN layers (intermediate layer sizes).
Location: main.py:44
Architecture: 100 → 256 → 256 → 256 (output)

Number of layers
Number of GNN convolutional layers. Range: 2-5 recommended.
Location: main.py:45
Note: Each layer is followed by GraphNorm, activation, and dropout.

Activation
Activation function for the GNN layers.
Options:
- elu: Exponential Linear Unit (default, smooth negatives)
- relu: Rectified Linear Unit (standard)
- leaky_relu: Leaky ReLU (prevents dead neurons)
- gelu: Gaussian Error Linear Unit (smooth, used in transformers)
Location: main.py:46
Implementation: model.py:27-34

Pooling
TopK pooling
Use TopK pooling instead of global mean/max pooling. TopK selects the most important nodes at each layer.
Location: main.py:41
Action: store_true (enabled by default)

Pooling ratio
Fraction of nodes to keep in TopK pooling. Range: (0, 1].
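Since the ratio is applied at every pooling layer and is clamped to a minimum of 0.3, its cumulative effect can be sketched as follows (my own illustration, not code from model.py; TopK is assumed to keep roughly ceil(ratio × n) nodes per layer):

```python
import math

def effective_ratio(ratio: float, min_ratio: float = 0.3) -> float:
    # Clamp the pooling ratio to avoid emptying the graph (cf. main.py:66-68).
    return max(ratio, min_ratio)

def nodes_after_pooling(num_nodes: int, ratio: float, num_layers: int) -> int:
    """Approximate node count after TopK pooling at each of num_layers layers."""
    r = effective_ratio(ratio)
    for _ in range(num_layers):
        # Keep roughly ceil(r * n) nodes, never fewer than one.
        num_nodes = max(1, math.ceil(r * num_nodes))
    return num_nodes
```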
Location: main.py:42
Safeguard: automatically clamped to a minimum of 0.3 to prevent empty graphs (main.py:66-68, model.py:74).

Transfer Learning
Freeze encoder
Freeze GNN encoder weights during temporal training. Use when you have a pretrained encoder.
Location: main.py:40
Action: store_true (include the flag to enable)
Pretrained path: {save_path}/pretrained_gnn_encoder.pth (created via supervised_pretrain.py)

Temporal Features
Time-aware prediction
Enable time-aware prediction by incorporating temporal gaps into the model.
Location: main.py:47
Action: store_true (include the flag to enable)
Effect: projects the time-to-predict into a 32D embedding and fuses it with the graph features (512D + 32D → 512D).

Time normalization
Method for normalizing temporal gaps before feeding them to the model.
Options:
- log: log(1 + months/12), compresses long time gaps (default)
- minmax: linear scaling to [0, 1]
- buckets: discretize into time bins
- raw: no normalization (use months directly)
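The log option above, for instance, can be sketched as follows (illustrative only; the actual code is in temporal_gap_processor.py and may differ, and minmax/buckets are omitted because they need dataset statistics):

```python
import math

def normalize_gap(months: float, method: str = "log") -> float:
    """Normalize a temporal gap in months. Sketch of the options listed above."""
    if method == "log":
        # Compresses long gaps: 0 months -> 0.0, 12 months -> log(2).
        return math.log1p(months / 12.0)
    if method == "raw":
        return float(months)
    raise NotImplementedError("minmax/buckets require dataset statistics")
```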
Location: main.py:49
Implementation: temporal_gap_processor.py

Default horizon
Default prediction horizon in months for subjects with only one visit.
Location: main.py:50
Fallback: used when the months_to_next field is missing or invalid.

Cross-Validation
Folds
Number of cross-validation folds. Set to 1 for a single train/val/test split.
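The split percentages follow from two successive 80/20 splits; a quick arithmetic check (my own sketch, not code from main.py):

```python
def split_fractions(test_frac: float = 0.2, val_frac: float = 0.2) -> tuple:
    """Fractions for (train, val, test) from two successive splits.

    First hold out test_frac for test, then split the remainder
    into val_frac validation and the rest training.
    """
    trainval = 1.0 - test_frac      # 0.8 of all subjects
    val = trainval * val_frac       # 0.8 * 0.2 = 0.16
    train = trainval - val          # 0.8 * 0.8 = 0.64
    return train, val, test_frac
```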
Location: main.py:37
Split ratio: 80% train+val (further split 80/20 → 64% train, 16% val), 20% test
Stratification: ensures a balanced converter/stable distribution across folds.

Example Configurations
Default Training
High-Capacity Model
GAT with Time Features
Transfer Learning from Pretrained
Strong Class Imbalance Handling
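As a sketch of what these configurations could look like on the command line (every flag name here is an assumption; check main.py's argparse definitions for the real ones, --freeze_encoder being the only flag confirmed above):

```shell
# Hypothetical examples; flag names are assumptions, not verified against main.py.

# Default Training
python main.py

# High-Capacity Model
python main.py --hidden_dim 512 --num_layers 3 --bidirectional

# GAT with Time Features
python main.py --conv_type GAT --time_aware --time_normalization log

# Transfer Learning from Pretrained
python main.py --freeze_encoder --save_path ./checkpoints

# Strong Class Imbalance Handling
python main.py --alpha 0.95 --gamma 4 --forcing_epochs 20
```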
Note: All boolean arguments use the store_true action. Include the flag to enable (e.g., --freeze_encoder); omit it to disable.