
Overview

The engine module provides the TrainingEngine class, which manages the complete training lifecycle: training loops, validation, metrics computation, learning rate scheduling, early stopping, and progress callbacks.

Classes

TrainingEngine

Manages the full training loop with comprehensive metrics tracking and control flow.
class TrainingEngine:
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: nn.Module,
        device: torch.device,
        scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
        early_stopping_patience: int = 0,
        checkpoint_callback: Callable | None = None,
        batch_callback: Callable[[int, int, dict], None] | None = None,
    )
Parameters
  • model (nn.Module, required): PyTorch model to train
  • train_loader (DataLoader, required): DataLoader for training data
  • val_loader (DataLoader, required): DataLoader for validation data
  • optimizer (torch.optim.Optimizer, required): Optimizer for updating model parameters
  • criterion (nn.Module, required): Loss function (e.g., CrossEntropyLoss, FocalLoss)
  • device (torch.device, required): Device to run training on (cuda, mps, or cpu)
  • scheduler (torch.optim.lr_scheduler.LRScheduler, default None): Optional learning rate scheduler
  • early_stopping_patience (int, default 0): Number of epochs without improvement before stopping (0 = disabled)
  • checkpoint_callback (Callable, default None): Optional callback function called after each epoch: callback(epoch, metrics, is_best)
  • batch_callback (Callable, default None): Optional callback function called every 10 batches: callback(batch_idx, total_batches, metrics)

Attributes

Training State
  • current_epoch (int): Current epoch number
  • best_val_loss (float): Best validation loss achieved
  • best_epoch (int): Epoch with best validation loss
  • epochs_without_improvement (int): Counter for early stopping
  • should_stop (bool): Flag to signal training stop
  • is_paused (bool): Flag indicating if training is paused
History
  • history (dict): Contains lists of metrics for each epoch:
    • train_loss: Training loss per epoch
    • train_acc: Training accuracy per epoch
    • train_precision: Training precision (macro) per epoch
    • train_recall: Training recall (macro) per epoch
    • train_f1: Training F1 score (macro) per epoch
    • val_loss: Validation loss per epoch
    • val_acc: Validation accuracy per epoch
    • val_precision: Validation precision (macro) per epoch
    • val_recall: Validation recall (macro) per epoch
    • val_f1: Validation F1 score (macro) per epoch
    • lr: Learning rate per epoch

Methods

train_epoch

Trains the model for one epoch.
def train_epoch(self) -> dict
Returns: Dictionary with training metrics:
  • train_loss: Average training loss
  • train_acc: Training accuracy
  • train_precision: Macro-averaged precision
  • train_recall: Macro-averaged recall
  • train_f1: Macro-averaged F1 score
Behavior:
  • Sets model to training mode
  • Iterates through training batches
  • Performs forward pass, backward pass, and optimizer step
  • Calls batch_callback every 10 batches if provided
  • Checks for should_stop flag to allow early termination
  • Computes comprehensive metrics using scikit-learn
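The behavior above can be sketched as a standalone function. This is a simplified illustration, not the engine's actual implementation: callback invocation and the should_stop check are omitted, and the helper name train_one_epoch is hypothetical.

```python
import torch
import torch.nn as nn
from sklearn.metrics import precision_recall_fscore_support

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Illustrative sketch of one training epoch with macro-averaged metrics."""
    model.train()  # enable dropout, batch-norm updates, etc.
    total_loss, all_preds, all_labels = 0.0, [], []
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()       # backward pass
        optimizer.step()      # parameter update
        total_loss += loss.item()
        all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
        all_labels.extend(labels.cpu().tolist())
    # Macro-averaged metrics via scikit-learn, as the engine does
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average="macro", zero_division=0
    )
    correct = sum(p == t for p, t in zip(all_preds, all_labels))
    return {
        "train_loss": total_loss / len(loader),
        "train_acc": correct / len(all_labels),
        "train_precision": precision,
        "train_recall": recall,
        "train_f1": f1,
    }
```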

validate

Evaluates the model on the validation set.
@torch.no_grad()
def validate(self) -> dict
Returns: Dictionary with validation metrics:
  • val_loss: Average validation loss
  • val_acc: Validation accuracy
  • val_precision: Macro-averaged precision
  • val_recall: Macro-averaged recall
  • val_f1: Macro-averaged F1 score
Behavior:
  • Sets model to evaluation mode
  • Disables gradient computation
  • Iterates through validation batches
  • Computes comprehensive metrics using scikit-learn
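The validation pass follows the same shape without gradient tracking or parameter updates. Again a simplified sketch rather than the engine's actual code; the standalone function name is hypothetical.

```python
import torch
from sklearn.metrics import precision_recall_fscore_support

@torch.no_grad()  # disable gradient computation for the whole pass
def validate(model, loader, criterion, device):
    """Illustrative sketch of a validation pass with macro-averaged metrics."""
    model.eval()  # evaluation mode: freeze dropout, batch-norm stats
    total_loss, all_preds, all_labels = 0.0, [], []
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        total_loss += criterion(outputs, labels).item()
        all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
        all_labels.extend(labels.cpu().tolist())
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average="macro", zero_division=0
    )
    correct = sum(p == t for p, t in zip(all_preds, all_labels))
    return {
        "val_loss": total_loss / len(loader),
        "val_acc": correct / len(all_labels),
        "val_precision": precision,
        "val_recall": recall,
        "val_f1": f1,
    }
```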

fit

Runs the complete training loop for the specified number of epochs.
def fit(
    self,
    epochs: int,
    update_callback: Callable[[int, dict], None] | None = None,
) -> dict
Parameters
  • epochs (int, required): Number of epochs to train
  • update_callback (Callable, default None): Optional callback called after each epoch with (epoch, metrics)
Returns: Dictionary with final training results:
  • final_epoch: Last completed epoch
  • best_epoch: Epoch with best validation loss
  • best_val_loss: Best validation loss achieved
  • duration: Training duration as formatted string (e.g., “5m 23s”)
  • history: Complete training history
Training Loop:
  1. Train for one epoch → train_epoch()
  2. Validate → validate()
  3. Update learning rate scheduler if provided
  4. Check if current epoch is best (lowest val_loss)
  5. Call checkpoint_callback if provided
  6. Print epoch summary
  7. Call update_callback if provided
  8. Check early stopping condition
  9. Handle pause/stop signals
Early Stopping:
  • Tracks epochs without improvement in validation loss
  • Stops training when epochs_without_improvement >= early_stopping_patience
  • Only active if early_stopping_patience > 0
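The early-stopping bookkeeping amounts to a small piece of state, sketched here as a pure function. This illustrates the logic described above; the engine tracks the same values as instance attributes (best_val_loss, epochs_without_improvement), and the function name is hypothetical.

```python
def check_early_stopping(val_loss, best_val_loss, epochs_without_improvement, patience):
    """Return updated (best_val_loss, epochs_without_improvement, should_stop)."""
    if val_loss < best_val_loss:
        # Improvement: record new best and reset the counter
        return val_loss, 0, False
    epochs_without_improvement += 1
    # Only stop when patience is enabled (> 0) and exhausted
    should_stop = patience > 0 and epochs_without_improvement >= patience
    return best_val_loss, epochs_without_improvement, should_stop
```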

stop

Signals the training loop to stop.
def stop(self)
Sets the should_stop flag to True, causing training to halt at the next batch or epoch boundary.

pause

Pauses the training loop.
def pause(self)
Sets the is_paused flag to True, causing training to wait before starting the next epoch.

resume

Resumes a paused training loop.
def resume(self)
Sets is_paused flag to False, allowing training to continue.

Example Usage

Basic Training

import torch
import torch.nn as nn
from training.engine import TrainingEngine
from training.dataset import create_dataloaders
from training.optimizers import create_optimizer, create_criterion

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = YourModel(num_classes=9).to(device)

# Create dataloaders
dataloaders, class_names, class_weights = create_dataloaders(
    dataset_config,
    training_config
)

# Create optimizer and loss
optimizer = create_optimizer(model, training_config)
criterion = create_criterion(training_config, class_weights, device)

# Create engine
engine = TrainingEngine(
    model=model,
    train_loader=dataloaders["train"],
    val_loader=dataloaders["val"],
    optimizer=optimizer,
    criterion=criterion,
    device=device
)

# Train
results = engine.fit(epochs=50)
print(f"Best epoch: {results['best_epoch']}")
print(f"Best val loss: {results['best_val_loss']:.4f}")

With Learning Rate Scheduler

from training.optimizers import create_scheduler

scheduler = create_scheduler(
    optimizer,
    training_config,
    steps_per_epoch=len(dataloaders["train"])
)

engine = TrainingEngine(
    model=model,
    train_loader=dataloaders["train"],
    val_loader=dataloaders["val"],
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    scheduler=scheduler  # Added scheduler
)

With Early Stopping

engine = TrainingEngine(
    model=model,
    train_loader=dataloaders["train"],
    val_loader=dataloaders["val"],
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    early_stopping_patience=10  # Stop if no improvement for 10 epochs
)

With Callbacks

def checkpoint_callback(epoch, metrics, is_best):
    """Save checkpoint after each epoch"""
    if is_best:
        torch.save(model.state_dict(), "best_model.pth")
        print(f"Saved best model at epoch {epoch}")

def update_callback(epoch, metrics):
    """Log metrics to external system"""
    wandb.log({"epoch": epoch, **metrics})

def batch_callback(batch_idx, total_batches, metrics):
    """Progress updates during training"""
    progress = batch_idx / total_batches * 100
    print(f"Progress: {progress:.1f}% | Loss: {metrics['batch_loss']:.4f}")

engine = TrainingEngine(
    model=model,
    train_loader=dataloaders["train"],
    val_loader=dataloaders["val"],
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    checkpoint_callback=checkpoint_callback,
    batch_callback=batch_callback
)

results = engine.fit(epochs=50, update_callback=update_callback)

Interactive Training Control

import threading
import time

results = {}

def train_in_thread():
    # Thread targets cannot return values; store the result for later use
    results["final"] = engine.fit(epochs=100)

# Start training in background
training_thread = threading.Thread(target=train_in_thread)
training_thread.start()

# Control training interactively
time.sleep(10)
engine.pause()  # Pause training
print("Training paused")

time.sleep(5)
engine.resume()  # Resume training
print("Training resumed")

# Stop training early if needed
if some_condition:
    engine.stop()
    print("Training stopped")

training_thread.join()

Accessing Training History

results = engine.fit(epochs=50)

# Access full history
history = results["history"]

# Plot training curves
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.plot(history["train_loss"], label="Train")
plt.plot(history["val_loss"], label="Val")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(history["train_acc"], label="Train")
plt.plot(history["val_acc"], label="Val")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(history["lr"])
plt.xlabel("Epoch")
plt.ylabel("Learning Rate")

plt.tight_layout()
plt.savefig("training_curves.png")
