Overview
The engine module provides the TrainingEngine class, which handles the complete training lifecycle: training loops, validation, metrics computation, learning rate scheduling, early stopping, and progress callbacks.
Classes
TrainingEngine
Manages the full training loop with comprehensive metrics tracking and control flow.

Constructor parameters:
- PyTorch model to train
- DataLoader for training data
- DataLoader for validation data
- Optimizer for updating model parameters
- Loss function (e.g., CrossEntropyLoss, FocalLoss)
- Device to run training on (cuda, mps, or cpu)
- Optional learning rate scheduler
- early_stopping_patience: number of epochs without improvement before stopping (0 = disabled)
- checkpoint_callback: optional callback called after each epoch as callback(epoch, metrics, is_best)
- batch_callback: optional callback called every 10 batches as callback(batch_idx, total_batches, metrics)

Attributes
Training state:
- current_epoch (int): Current epoch number
- best_val_loss (float): Best validation loss achieved
- best_epoch (int): Epoch with best validation loss
- epochs_without_improvement (int): Counter for early stopping
- should_stop (bool): Flag to signal training stop
- is_paused (bool): Flag indicating if training is paused
History:
- history (dict): Contains lists of metrics for each epoch:
  - train_loss: Training loss per epoch
  - train_acc: Training accuracy per epoch
  - train_precision: Training precision (macro) per epoch
  - train_recall: Training recall (macro) per epoch
  - train_f1: Training F1 score (macro) per epoch
  - val_loss: Validation loss per epoch
  - val_acc: Validation accuracy per epoch
  - val_precision: Validation precision (macro) per epoch
  - val_recall: Validation recall (macro) per epoch
  - val_f1: Validation F1 score (macro) per epoch
  - lr: Learning rate per epoch
Methods
train_epoch
Trains the model for one epoch:

- Sets the model to training mode
- Iterates through training batches
- Performs the forward pass, backward pass, and optimizer step
- Calls batch_callback every 10 batches if provided
- Checks the should_stop flag to allow early termination
- Computes comprehensive metrics using scikit-learn

Returns a dict with:
- train_loss: Average training loss
- train_acc: Training accuracy
- train_precision: Macro-averaged precision
- train_recall: Macro-averaged recall
- train_f1: Macro-averaged F1 score
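The control flow above can be sketched in a few lines. This is a simplified stand-in, not the actual implementation: the hypothetical `step` function represents the real forward/backward/optimizer logic, and the real method also accumulates predictions and labels for the scikit-learn metrics.

```python
# Simplified sketch of the train_epoch control flow. `step` stands in for the
# forward pass, backward pass, and optimizer step; the actual method also
# collects predictions to compute the scikit-learn metrics listed above.
def train_epoch_sketch(batches, step, batch_callback=None, should_stop=lambda: False):
    total_loss, seen = 0.0, 0
    for batch_idx, batch in enumerate(batches):
        if should_stop():  # honor the stop flag mid-epoch
            break
        total_loss += step(batch)  # forward + backward + optimizer step
        seen += 1
        if batch_callback is not None and batch_idx % 10 == 0:
            batch_callback(batch_idx, len(batches), {"train_loss": total_loss / seen})
    return {"train_loss": total_loss / max(seen, 1)}
```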
validate
Evaluates the model on the validation set:

- Sets the model to evaluation mode
- Disables gradient computation
- Iterates through validation batches
- Computes comprehensive metrics using scikit-learn

Returns a dict with:
- val_loss: Average validation loss
- val_acc: Validation accuracy
- val_precision: Macro-averaged precision
- val_recall: Macro-averaged recall
- val_f1: Macro-averaged F1 score
fit
Runs the complete training loop for the specified number of epochs.

Parameters:
- Number of epochs to train
- Optional callback called after each epoch with (epoch, metrics)

Returns a dict with:
- final_epoch: Last completed epoch
- best_epoch: Epoch with best validation loss
- best_val_loss: Best validation loss achieved
- duration: Training duration as formatted string (e.g., "5m 23s")
- history: Complete training history
Each epoch:
- Train for one epoch → train_epoch()
- Validate → validate()
- Update the learning rate scheduler if provided
- Check whether the current epoch is the best (lowest val_loss)
- Call checkpoint_callback if provided
- Print the epoch summary
- Call update_callback if provided
- Check the early stopping condition
- Handle pause/stop signals
Early stopping:
- Tracks epochs without improvement in validation loss
- Stops training when epochs_without_improvement >= early_stopping_patience
- Only active if early_stopping_patience > 0
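The early-stopping rule reduces to a small piece of bookkeeping. A sketch, using the attribute names documented above (the function itself is illustrative, not part of the class API):

```python
# Early-stopping bookkeeping as described above: reset the counter on
# improvement, otherwise increment it and stop once it reaches the patience.
# A patience of 0 disables the check entirely.
def early_stop_update(val_loss, best_val_loss, epochs_without_improvement, patience):
    if val_loss < best_val_loss:
        return val_loss, 0, False  # new best: reset counter, keep training
    epochs_without_improvement += 1
    should_stop = patience > 0 and epochs_without_improvement >= patience
    return best_val_loss, epochs_without_improvement, should_stop
```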
stop
Signals the training loop to stop. Sets the should_stop flag to True, causing training to halt at the next checkpoint.
pause
Pauses the training loop. Sets the is_paused flag to True, causing training to wait before starting the next epoch.
resume
Resumes a paused training loop. Sets the is_paused flag to False, allowing training to continue.
Example Usage
Basic Training
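A minimal usage sketch. The import path, the positional argument order, and MyModel/train_loader/val_loader are assumptions standing in for your own module, model, and DataLoaders:

```python
import torch
from engine import TrainingEngine  # import path is an assumption

device = "cuda" if torch.cuda.is_available() else "cpu"
model = MyModel().to(device)  # any nn.Module; MyModel is a placeholder
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

engine = TrainingEngine(model, train_loader, val_loader,
                        optimizer, criterion, device)
result = engine.fit(20)  # train for 20 epochs
print(result["best_epoch"], result["best_val_loss"])
```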
With Learning Rate Scheduler
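A sketch passing a scheduler as the optional constructor argument documented above (the `scheduler=` keyword name is an assumption); it presumes model, loaders, optimizer, criterion, and device are already set up:

```python
from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)
engine = TrainingEngine(model, train_loader, val_loader, optimizer,
                        criterion, device, scheduler=scheduler)
engine.fit(30)  # the engine updates the scheduler after each epoch
```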
With Early Stopping
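A sketch enabling early stopping via the early_stopping_patience parameter (setup of model, loaders, optimizer, criterion, and device is assumed):

```python
engine = TrainingEngine(model, train_loader, val_loader, optimizer,
                        criterion, device, early_stopping_patience=5)
result = engine.fit(100)  # halts after 5 epochs without val_loss improvement
print(f"stopped at epoch {result['final_epoch']}, best was {result['best_epoch']}")
```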
With Callbacks
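A sketch wiring up both callbacks, using the signatures documented above (the save path and print format are illustrative; prior setup is assumed):

```python
def on_epoch(epoch, metrics, is_best):
    if is_best:  # checkpoint only when val_loss improves
        torch.save(model.state_dict(), "best_model.pt")

def on_batch(batch_idx, total_batches, metrics):
    print(f"batch {batch_idx}/{total_batches}: loss={metrics['train_loss']:.4f}")

engine = TrainingEngine(model, train_loader, val_loader, optimizer, criterion,
                        device, checkpoint_callback=on_epoch,
                        batch_callback=on_batch)
engine.fit(20)
```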
Interactive Training Control
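Because stop(), pause(), and resume() only set flags, the training loop must run concurrently with the controlling code for them to be useful interactively. One way to arrange that, sketched with a background thread (the threading setup is an assumption, not something the engine provides):

```python
import threading

# Run fit() on a worker thread so the main thread can control it.
worker = threading.Thread(target=engine.fit, args=(50,))
worker.start()

engine.pause()   # training waits before starting the next epoch
engine.resume()  # training continues
engine.stop()    # training halts at the next checkpoint
worker.join()
```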
Accessing Training History
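Since history is a plain dict of per-epoch lists, it can be inspected or plotted directly. Illustrated here with hard-coded values in the documented structure (only a subset of the keys is shown):

```python
# engine.history has this shape after training; the values are illustrative.
history = {
    "train_loss": [0.92, 0.61, 0.47],
    "val_loss":   [0.95, 0.70, 0.52],
    "val_acc":    [0.55, 0.68, 0.74],
    "lr":         [1e-3, 1e-3, 5e-4],
}

# Find the best epoch by validation loss (0-indexed; +1 for display).
best = min(range(len(history["val_loss"])), key=history["val_loss"].__getitem__)
print(f"best epoch: {best + 1}, val_loss={history['val_loss'][best]:.2f}, "
      f"val_acc={history['val_acc'][best]:.2f}")
```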
Related
- Optimizer Utilities - Optimizer, scheduler, and loss creation
- Evaluator - Test set evaluation
- Dataset - DataLoader creation