
Usage

Run training with:
python main.py [OPTIONS]

Training Parameters

Basic Training

--n_epochs
int
default:"100"
Number of training epochs. Training continues until either this limit is reached or early stopping triggers (20 epochs patience).
Location: main.py:25
--batch_size
int
default:"16"
Batch size for temporal sequences (number of subjects per batch). Each subject may have multiple visits, so the actual number of graphs processed per batch is higher.
Location: main.py:26
Note: Memory usage scales with batch size × max visits per subject.
--lr
float
default:"0.001"
Learning rate for the Adam optimizer. Applied to both encoder and classifier parameters.
Location: main.py:27
Scheduler: Automatically reduced by 50% after 10 epochs without validation improvement (min: 1e-6).
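The plateau behavior described above can be sketched in plain Python. This is an illustrative approximation of the scheduler rule (halve after 10 stalled epochs, floor at 1e-6), not the project's actual implementation:

```python
def plateau_lr(lr, epochs_since_improvement, patience=10, factor=0.5, min_lr=1e-6):
    """Halve the learning rate once validation has stalled for `patience`
    epochs, never dropping below `min_lr`. Illustrative sketch only."""
    if epochs_since_improvement >= patience:
        return max(lr * factor, min_lr)
    return lr
```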
--save_path
str
default:"./model/"
Directory path for saving trained models. The best model for each fold is saved as best_model_fold{N}.pth.
Location: main.py:31
--save_model
bool
default:"true"
Whether to save the best model checkpoint for each fold.
Location: main.py:32

Loss Function Parameters

Focal Loss

--focal_alpha
float
default:"0.90"
Weight for the minority class (converters) in Focal Loss. Range: [0, 1].
  • alpha = 0.90: 90% weight to converters (minority)
  • alpha = 0.50: Equal weight to both classes
  • Higher values increase focus on minority class
Location: main.py:28
Implementation: See FocalLoss.py:9-29
--focal_gamma
float
default:"3.0"
Focusing parameter for Focal Loss. Higher values down-weight easy examples more aggressively.
  • gamma = 0: Equivalent to standard cross-entropy
  • gamma = 2: Original Focal Loss paper recommendation
  • gamma = 3: Stronger focus on hard examples (current default)
Location: main.py:29
Formula: Loss = -α(1-p_t)^γ log(p_t)
--label_smoothing
float
default:"0.05"
Label smoothing factor applied before computing Focal Loss. Prevents overconfident predictions.
  • 0.0: No smoothing (hard labels)
  • 0.05: 5% smoothing (current default)
  • 0.1: 10% smoothing (higher regularization)
Location: main.py:30
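How alpha, gamma, and label smoothing interact can be sketched for the binary case. This is a minimal illustration of the formula above with smoothed targets, not the code in FocalLoss.py; `focal_loss` and its signature are hypothetical:

```python
import math

def focal_loss(p, y, alpha=0.90, gamma=3.0, smoothing=0.05):
    """Illustrative binary Focal Loss with label smoothing applied first.
    p: predicted probability of the converter class; y: hard label in {0, 1}.
    Implements Loss = -alpha_t * (1 - p_t)^gamma * log(p_t) on smoothed targets."""
    y_smooth = y * (1 - smoothing) + 0.5 * smoothing  # soften hard labels
    # probability mass assigned to the (smoothed) target, and its class weight
    p_t = p * y_smooth + (1 - p) * (1 - y_smooth)
    alpha_t = alpha * y_smooth + (1 - alpha) * (1 - y_smooth)
    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)
```

With alpha = 0.5, gamma = 0, and smoothing = 0, this reduces to standard cross-entropy (scaled by 0.5), matching the gamma = 0 note above.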

Minority Class Forcing

--minority_focus_epochs
int
default:"20"
Number of initial epochs during which the minority-class forcing loss is applied. This additional loss term encourages the model to predict the converter class during early training.
Location: main.py:33
Implementation: The forcing weight decays linearly from 0.1 to 0 over these epochs (main.py:465-482).
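The linear decay described above can be sketched as follows; this is an assumed form of the schedule (see main.py:465-482 for the actual implementation):

```python
def forcing_weight(epoch, focus_epochs=20, start=0.1):
    """Linear decay of the minority-forcing weight: starts at `start`,
    reaches 0 at `focus_epochs`, and stays at 0 afterwards. Sketch only."""
    if epoch >= focus_epochs:
        return 0.0
    return start * (1 - epoch / focus_epochs)
```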

Temporal Model Parameters

RNN Architecture

--model_type
str
default:"LSTM"
Type of recurrent neural network for temporal modeling.
Options:
  • LSTM: Long Short-Term Memory (default, best for long sequences)
  • GRU: Gated Recurrent Unit (fewer parameters, faster training)
  • RNN: Vanilla RNN (simplest, may have gradient issues)
Location: main.py:38
Models: TemporalPredictor.py, GRUPredictor.py, RNNPredictor.py
--lstm_hidden_dim
int
default:"64"
Hidden dimension for LSTM/GRU/RNN layers. Controls the capacity of temporal modeling.
Location: main.py:34
Note: The final classification input is hidden_dim × 2 for a bidirectional LSTM, otherwise hidden_dim.
--lstm_num_layers
int
default:"1"
Number of stacked LSTM/GRU/RNN layers.
Location: main.py:35
Note: Dropout between layers (dropout=0.45) is applied automatically when num_layers > 1.
--lstm_bidirectional
bool
default:"true"
Use a bidirectional LSTM (processes sequences both forward and backward).
Location: main.py:36
Note: Only supported for the LSTM model type; GRU and RNN are always unidirectional.

Sequence Processing

--max_visits
int
default:"10"
Maximum number of visits to include per subject. Longer sequences are trimmed to keep only the most recent visits.
Location: main.py:39
Trimming logic: main.py:110-115
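The trimming rule amounts to keeping the tail of the visit sequence; a one-line sketch of the behavior described above (the actual logic lives in main.py:110-115):

```python
def trim_visits(visits, max_visits=10):
    """Keep only the `max_visits` most recent entries of a visit sequence.
    Illustrative sketch of the trimming described in the docs."""
    return visits[-max_visits:]
```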
--exclude_target_visit
bool
default:"false"
Exclude the target visit from input sequences to prevent data leakage. When enabled:
  • Multi-visit subjects: Use visits 1 to N-1 as input, predict visit N
  • Single-visit subjects: Use current visit, predict future state
Location: main.py:48
Action: store_true (include the flag to enable)
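The two-case rule above can be sketched as follows. `build_input_sequence` is a hypothetical helper, not a function from the codebase; for single-visit subjects the "target" stands in for the future state being predicted:

```python
def build_input_sequence(visits, exclude_target=True):
    """Sketch of the leakage-avoidance rule: multi-visit subjects use
    visits 1..N-1 as input and predict visit N; single-visit subjects
    use their only visit as input. Returns (input_visits, target_visit)."""
    if exclude_target and len(visits) > 1:
        return visits[:-1], visits[-1]   # predict the held-out last visit
    return visits, visits[-1]            # single visit (or flag disabled)
```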

Graph Neural Network Parameters

Architecture

--layer_type
str
default:"GraphSAGE"
Type of graph convolutional layer.
Options:
  • GraphSAGE: Inductive learning via sampling and aggregating (default)
  • GCN: Graph Convolutional Network (spectral approach)
  • GAT: Graph Attention Network (learned edge importance)
Location: main.py:43
Implementation: model.py:56-63
--gnn_hidden_dim
int
default:"256"
Hidden dimension for GNN layers (intermediate layer sizes).
Location: main.py:44
Architecture: 100 → 256 → 256 → 256 (output)
--gnn_num_layers
int
default:"2"
Number of GNN convolutional layers. Range: 2-5 recommended.
Location: main.py:45
Note: Each layer is followed by GraphNorm, activation, and dropout.
--gnn_activation
str
default:"elu"
Activation function for GNN layers.
Options:
  • elu: Exponential Linear Unit (default, smooth negatives)
  • relu: Rectified Linear Unit (standard)
  • leaky_relu: Leaky ReLU (prevents dead neurons)
  • gelu: Gaussian Error Linear Unit (smooth, used in transformers)
Location: main.py:46
Implementation: model.py:27-34

Pooling

--use_topk_pooling
bool
default:"true"
Use TopK pooling instead of global mean/max pooling. TopK selects the most important nodes at each layer.
Location: main.py:41
Action: store_true (enabled by default)
--topk_ratio
float
default:"0.3"
Fraction of nodes to keep in TopK pooling. Range: (0, 1].
Location: main.py:42
Safeguard: Automatically clamped to a minimum of 0.3 to prevent empty graphs (main.py:66-68, model.py:74).
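The safeguard is a simple clamp; a sketch of the effective ratio after the check described above (`effective_topk_ratio` is a hypothetical name, see main.py:66-68 for the real code):

```python
def effective_topk_ratio(requested, minimum=0.3):
    """Clamp the requested TopK ratio so pooling cannot empty small graphs."""
    return max(requested, minimum)
```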

Transfer Learning

--freeze_encoder
bool
default:"false"
Freeze GNN encoder weights during temporal training. Use when you have a pretrained encoder.
Location: main.py:40
Action: store_true (include the flag to enable)
Pretrained path: {save_path}/pretrained_gnn_encoder.pth (created via supervised_pretrain.py)

Temporal Features

--use_time_features
bool
default:"false"
Enable time-aware prediction by incorporating temporal gaps into the model.
Location: main.py:47
Action: store_true (include the flag to enable)
Effect: Projects time-to-predict into a 32D embedding and fuses it with graph features (512D + 32D → 512D).
--time_normalization
str
default:"log"
Method for normalizing temporal gaps before they are fed to the model.
Options:
  • log: log(1 + months/12) - compresses long time gaps (default)
  • minmax: Linear scaling to [0, 1]
  • buckets: Discretize into time bins
  • raw: No normalization (use months directly)
Location: main.py:49
Implementation: temporal_gap_processor.py
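The four options can be sketched as below. Only the `log` formula is stated in the docs; the `minmax` ceiling (120 months) and the `buckets` bin width (6 months, capped at 10 bins) are illustrative assumptions, not values from temporal_gap_processor.py:

```python
import math

def normalize_gap(months, method="log"):
    """Illustrative normalization of a temporal gap in months."""
    if method == "log":
        return math.log(1 + months / 12)    # documented formula; compresses long gaps
    if method == "minmax":
        return min(months / 120.0, 1.0)     # ASSUMED 120-month ceiling
    if method == "buckets":
        return min(int(months // 6), 9)     # ASSUMED 6-month bins, 10 bins
    return float(months)                    # "raw": use months directly
```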
--single_visit_horizon
int
default:"6"
Default prediction horizon in months for subjects with only one visit.
Location: main.py:50
Fallback: Used when the months_to_next field is missing or invalid.

Cross-Validation

--num_folds
int
default:"5"
Number of cross-validation folds. Set to 1 for a single train/val/test split.
Location: main.py:37
Split ratio: 80% train+val (further split 80/20 → 64% train, 16% val), 20% test
Stratification: Ensures a balanced converter/stable distribution across folds.
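The nested-split arithmetic above (80% × 80% = 64% train, 80% × 20% = 16% val, 20% test) can be verified with a small sketch; `split_fractions` is a hypothetical helper, not project code:

```python
def split_fractions(test_frac=0.20, val_frac_of_trainval=0.20):
    """Compute the overall train/val/test fractions of the nested split:
    hold out `test_frac`, then split the remainder 80/20 into train/val."""
    trainval = 1.0 - test_frac              # 0.80
    val = trainval * val_frac_of_trainval   # 0.80 * 0.20 = 0.16
    train = trainval - val                  # 0.80 - 0.16 = 0.64
    return train, val, test_frac
```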

Example Configurations

Default Training

python main.py

High-Capacity Model

python main.py \
  --gnn_hidden_dim 512 \
  --gnn_num_layers 4 \
  --lstm_hidden_dim 128 \
  --lstm_num_layers 2

GAT with Time Features

python main.py \
  --layer_type GAT \
  --use_time_features \
  --exclude_target_visit \
  --time_normalization log

Transfer Learning from Pretrained

python main.py \
  --freeze_encoder \
  --lr 0.0005 \
  --n_epochs 50

Strong Class Imbalance Handling

python main.py \
  --focal_alpha 0.95 \
  --focal_gamma 4.0 \
  --minority_focus_epochs 30 \
  --label_smoothing 0.1
All boolean arguments use the store_true action: include the flag to enable (e.g., --freeze_encoder), omit it to disable.
When using --use_topk_pooling, avoid very small --topk_ratio values, which can leave graphs empty after pooling; the system enforces a minimum of 0.3.
