Overview

The train.py module provides the main training pipeline with experiment tracking, dataset validation, and checkpoint management. It supports both real datasets (Fashion-MNIST) and synthetic data generation.

run_experiment

Main function to run a complete training experiment with tracking and checkpointing.
def run_experiment(config_name: str)

Parameters

config_name
string
required
Name of a predefined experiment configuration from EXPERIMENT_CONFIGS, or path to a JSON configuration file

Configuration Options

The experiment configuration (JSON or predefined) supports the following fields:
layer_sizes
list[int]
required
Architecture specification, e.g., [784, 64, 10] for input, hidden, and output layers
activations
list[string]
required
Activation functions for each layer transition, e.g., ["relu", "softmax"]
epochs
int
default:"3"
Number of training epochs
alpha
float
default:"0.1"
Learning rate
batch_size
int
default:"32"
Training batch size
seed
int
default:"42"
Random seed for reproducibility
precision
string
default:"float32"
Numeric precision: "float32", "float16", or "int8"
val_ratio
float
default:"0.1"
Proportion of training data to use for validation
synthetic_mode
bool
default:"false"
If true, generates synthetic random data instead of loading real dataset
synthetic_samples
int
default:"512"
Number of synthetic samples to generate (only used when synthetic_mode=true)
dataset_path
string
Custom path to training dataset (defaults to Fashion-MNIST train path)
dataset_min_rows
int
default:"100"
Minimum required rows for dataset validation
dataset_auto_prepare
bool
default:"false"
Automatically download dataset if missing
dataset_sha256
string
Expected SHA-256 hash for dataset integrity verification
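Putting the fields together, a custom configuration file might look like the following sketch (the values shown are illustrative choices, not requirements; only `layer_sizes` and `activations` are required, and omitted fields fall back to the documented defaults):

```python
import json

# Hypothetical experiment config exercising the fields documented above.
config = {
    "layer_sizes": [784, 64, 10],        # required: input, hidden, output
    "activations": ["relu", "softmax"],  # required: one per layer transition
    "epochs": 3,
    "alpha": 0.1,
    "batch_size": 32,
    "seed": 42,
    "precision": "float32",
    "val_ratio": 0.1,
    "synthetic_mode": False,
    "dataset_min_rows": 100,
}

# Write it out so it can be passed to run_experiment as a path.
with open("my_experiment.json", "w") as f:
    json.dump(config, f, indent=2)
```

The resulting file can then be passed to `run_experiment("my_experiment.json")`.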

Behavior

  1. Configuration Resolution: Loads config from EXPERIMENT_CONFIGS or JSON file
  2. Seed Initialization: Sets global seed for reproducibility
  3. Model Creation: Instantiates NeuralNetwork with specified architecture
  4. Data Loading: Loads real dataset or generates synthetic data
  5. Train/Val Split: Splits data according to val_ratio
  6. Experiment Tracking: Creates experiment record with metadata
  7. Training: Runs model training with validation
  8. Checkpoint Saving: Saves model weights to experiments/checkpoints/
  9. Logging: Writes experiment history to experiments/logs/

Output Files

  • Checkpoint: experiments/checkpoints/{experiment_id}_v{version}.npz
  • History Log: experiments/logs/{experiment_id}.json
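Both output files use standard formats and can be inspected directly. The snippet below fabricates a checkpoint and history log with placeholder names and array keys (the real files live under `experiments/` and their exact contents depend on the run), then reads them back:

```python
import json
import numpy as np

# Fabricate placeholder output files so this snippet is self-contained;
# substitute the paths printed by run_experiment for a real run.
np.savez("demo_v1.npz", W0=np.zeros((784, 64)), W1=np.zeros((64, 10)))
with open("demo.json", "w") as f:
    json.dump({"experiment_id": "demo", "version": 1}, f)

# Checkpoints are NumPy .npz archives: list the stored array names.
ckpt = np.load("demo_v1.npz")
print(sorted(ckpt.files))

# History logs are plain JSON.
with open("demo.json") as f:
    history = json.load(f)
print(history["experiment_id"])
```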

Example Usage

from train import run_experiment

# Using predefined config
run_experiment("baseline")

# Using custom JSON config
run_experiment("configs/my_experiment.json")

CLI Usage

Run training from the command line:
python train.py --experiment baseline
python train.py --experiment configs/custom.json

CLI Arguments

--experiment
string
required
Experiment config name or path to JSON config file
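A CLI accepting this argument can be reproduced with `argparse`; this is a minimal sketch of the documented interface, not the actual `train.py` entry point:

```python
import argparse

# Minimal sketch of the documented CLI surface.
parser = argparse.ArgumentParser(description="Run a training experiment")
parser.add_argument(
    "--experiment",
    required=True,
    help="Experiment config name or path to a JSON config file",
)

# Parse a sample invocation (equivalent to: python train.py --experiment baseline)
args = parser.parse_args(["--experiment", "baseline"])
print(args.experiment)
```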

Example Output

[dataset] loading data/fashion_mnist_train.npz
Experiment logged: baseline_20260304 v1
History file: experiments/logs/baseline_20260304.json
Checkpoint: experiments/checkpoints/baseline_20260304_v1.npz
Related

  • inference - Run inference on trained models
  • benchmark - Performance benchmarking utilities
