Overview
The experiments module provides functions for training models, running optimization sweeps across pruning levels and quantization precisions, computing Pareto frontiers, and generating visualization plots.

Functions
train_model
Trains a PyTorch model using the Adam optimizer and cross-entropy loss.

Parameters
- The PyTorch model to train. Can be any torch.nn.Module instance.
- PyTorch DataLoader providing training batches. Must yield (inputs, targets) tuples.
- Number of complete passes through the training dataset.
- Learning rate for the Adam optimizer. Typical values: 0.001, 0.0001.
- Device to train on (e.g., torch.device("cuda") or torch.device("cpu")).

Returns
The trained model with updated weights. The same model instance that was passed in (modified in-place).
Implementation Details
Training loop implementation:
- Moves model to specified device
- Sets model to training mode
- Creates Adam optimizer with specified learning rate
- Uses CrossEntropyLoss criterion
- For each epoch, iterates through all batches:
  - Moves data to device
  - Forward pass
  - Computes loss
  - Backpropagation
  - Optimizer step
This function modifies the model in-place. If you need to preserve the original model, create a copy before training.
Example
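A minimal sketch of the loop described above, with a toy usage. The function name train_model comes from this page, but the parameter order (model, train_loader, epochs, lr, device) is an assumption:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_model(model, train_loader, epochs, lr, device):
    # Sketch of the documented loop: Adam + CrossEntropyLoss, in-place training.
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
    return model  # same instance, modified in-place

# Tiny synthetic classification task for illustration
X = torch.randn(32, 4)
y = torch.randint(0, 3, (32,))
loader = DataLoader(TensorDataset(X, y), batch_size=8)
net = nn.Linear(4, 3)
trained = train_model(net, loader, epochs=2, lr=1e-3, device=torch.device("cpu"))
assert trained is net  # same object returned, as documented
```

Because the model is modified in-place, wrap it in copy.deepcopy(net) beforehand if the original weights must be preserved.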
run_sweep
Runs a comprehensive hyperparameter sweep across pruning levels and quantization precisions, collecting performance metrics for each configuration.

Parameters
- The trained base model to optimize. This model will be pruned and quantized in various configurations.
- Validation DataLoader for accuracy evaluation. Must yield (inputs, targets) tuples.
- DataLoader providing calibration data for INT8 quantization. Should contain representative samples.
- Device to run benchmarks on (CPU or CUDA).
- List of pruning levels to sweep. Each value should be in [0.0, 1.0). Example: [0.0, 0.2, 0.4, 0.6, 0.8].
- List of precision formats to test. Valid values: "fp32", "fp16", "int8".
- Device power consumption in watts for the energy proxy calculation. Typical values: 1.0-5.0 for edge devices.
- Number of calibration batches to use for INT8 quantization.
- List of memory budget thresholds in MB to check for violations. Example: [1.0, 2.0, 5.0].
- The active memory budget threshold in MB. Configurations exceeding this are marked as rejected (accepted=False).
- Multiplier to scale measured latency (e.g., to simulate different hardware). Use 1.0 for no scaling.
- Number of times to repeat latency benchmarks for statistical robustness.
Returns
A pandas DataFrame containing metrics for each configuration. Each row represents one configuration with the following columns:

Configuration:
- pruning_level: Pruning level applied (0.0 to <1.0)
- precision: Precision format used ("fp32", "fp16", "int8")
- accepted: Boolean indicating if the configuration meets the active memory budget
- active_budget_mb: The active memory budget threshold used

Metrics:
- accuracy: Model accuracy on the validation set (0.0 to 1.0)
- latency_ms: Average inference latency in milliseconds
- latency_std_ms: Standard deviation of latency
- latency_p95_ms: 95th percentile latency
- throughput_sps: Throughput in samples per second
- memory_mb: Model memory footprint in megabytes
- energy_proxy_j: Energy proxy in joules (latency_ms × power_watts / 1000)
- violates_{budget}mb: Boolean for each budget in memory_budgets_mb
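The energy proxy and budget-violation columns are plain arithmetic over the measured values. A sketch, using hypothetical helper names (the real module computes these inside the sweep):

```python
def energy_proxy_j(latency_ms: float, power_watts: float) -> float:
    # energy_proxy_j = latency_ms × power_watts / 1000 (milliseconds → seconds)
    return latency_ms * power_watts / 1000.0

def budget_flags(memory_mb: float, memory_budgets_mb: list) -> dict:
    # One violates_{budget}mb boolean per threshold in memory_budgets_mb
    return {f"violates_{b}mb": memory_mb > b for b in memory_budgets_mb}

print(energy_proxy_j(20.0, 2.5))          # 0.05 (joules)
print(budget_flags(3.2, [1.0, 2.0, 5.0])) # violates 1 MB and 2 MB, fits in 5 MB
```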
Implementation Details
The sweep process:
- Iterate Configurations: For each combination of pruning level and precision
- Apply Pruning: Use structured_channel_prune with the pruning level
- Apply Quantization: Convert to the specified precision (fp32/fp16/int8)
- Collect Metrics: Run comprehensive benchmarks using collect_metrics
- Check Budgets: Determine if the configuration is accepted and check violations
- Aggregate Results: Compile all results into a pandas DataFrame
The sweep can generate a large number of configurations. For pruning_levels=[0.0, 0.2, 0.4, 0.6] and precisions=["fp32", "fp16", "int8"], you'll get 4 × 3 = 12 configurations.

Example
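A simplified sketch of the sweep loop. The pruning, quantization, and metric-collection steps are replaced by a hypothetical stub (fake_metrics); the real sweep uses structured_channel_prune and collect_metrics:

```python
import pandas as pd

def fake_metrics(level, precision):
    # Stand-in for prune → quantize → benchmark; toy models only, for illustration.
    bits = {"fp32": 32, "fp16": 16, "int8": 8}[precision]
    memory_mb = 4.0 * (1 - level) * bits / 32
    latency_ms = 10.0 * (1 - level) * bits / 32
    accuracy = 0.92 - 0.1 * level - 0.01 * (32 - bits) / 24
    return accuracy, latency_ms, memory_mb

def run_sweep_sketch(pruning_levels, precisions, power_watts, active_budget_mb):
    rows = []
    for level in pruning_levels:                 # Iterate Configurations
        for precision in precisions:
            acc, lat, mem = fake_metrics(level, precision)  # Collect Metrics
            rows.append({
                "pruning_level": level,
                "precision": precision,
                "accuracy": acc,
                "latency_ms": lat,
                "memory_mb": mem,
                "energy_proxy_j": lat * power_watts / 1000.0,
                "accepted": mem <= active_budget_mb,        # Check Budgets
                "active_budget_mb": active_budget_mb,
            })
    return pd.DataFrame(rows)                    # Aggregate Results

df = run_sweep_sketch([0.0, 0.2, 0.4, 0.6], ["fp32", "fp16", "int8"],
                      power_watts=2.0, active_budget_mb=2.0)
print(len(df))  # 4 levels × 3 precisions = 12 configurations
```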
pareto_frontier
Computes the Pareto frontier of accepted configurations by selecting models that achieve the best accuracy for progressively increasing values of a constraint metric (latency or energy).

Parameters
- DataFrame of sweep results from run_sweep. Must contain the columns accepted, accuracy, and the column specified in x_col.
- The constraint column name to optimize along (e.g., "latency_ms", "energy_proxy_j", "memory_mb"). Lower values of this metric are preferred.

Returns
A DataFrame containing only the Pareto-optimal configurations. These are configurations where no other configuration achieves both better accuracy AND a better constraint metric value.

Properties:
- Only includes accepted configurations (where accepted=True)
- Sorted by increasing constraint metric (x_col)
- Each row represents a non-dominated solution
- Accuracy is strictly increasing along the frontier
Implementation Details
Pareto frontier algorithm:
- Filter Accepted: Only consider configurations meeting the active memory budget
- Sort: Sort by constraint metric (ascending) and accuracy (descending)
- Select Non-Dominated: Iterate through sorted configurations:
  - Keep a configuration if it has better accuracy than all previous
  - Track the best accuracy seen so far
  - Skip dominated configurations
A configuration is Pareto-optimal if there’s no other configuration that is strictly better in all objectives. This function implements a simple greedy algorithm for the accuracy-vs-constraint trade-off.
Example
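The greedy algorithm above can be sketched as follows; pareto_frontier_sketch is an illustrative reimplementation, not the module's own code:

```python
import pandas as pd

def pareto_frontier_sketch(df: pd.DataFrame, x_col: str) -> pd.DataFrame:
    # Filter Accepted, then Sort by constraint (ascending) and accuracy (descending)
    cand = df[df["accepted"]].sort_values([x_col, "accuracy"],
                                          ascending=[True, False])
    kept, best_acc = [], float("-inf")
    for _, row in cand.iterrows():
        # Keep only rows strictly improving on the best accuracy seen so far
        if row["accuracy"] > best_acc:
            kept.append(row)
            best_acc = row["accuracy"]
    return pd.DataFrame(kept).reset_index(drop=True)

df = pd.DataFrame({
    "latency_ms": [5, 7, 8, 10, 12, 6],
    "accuracy":   [0.80, 0.85, 0.82, 0.90, 0.95, 0.99],
    "accepted":   [True, True, True, True, True, False],
})
front = pareto_frontier_sketch(df, "latency_ms")
print(front["latency_ms"].tolist())  # [5, 7, 10, 12]
```

The 8 ms / 0.82 row is dominated (7 ms already reaches 0.85), and the rejected 6 ms / 0.99 row is excluded despite its high accuracy, so the returned frontier has strictly increasing accuracy.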
save_plots
Generates and saves three visualization plots showing the trade-offs between accuracy and optimization metrics (latency, energy, memory).

Parameters
- Complete DataFrame of sweep results from run_sweep. Must contain the columns accepted, accuracy, latency_ms, energy_proxy_j, memory_mb.
- Pareto frontier DataFrame for latency (from pareto_frontier(df, "latency_ms")).
- Pareto frontier DataFrame for energy (from pareto_frontier(df, "energy_proxy_j")).
- Directory path where plots will be saved. Will be created if it doesn't exist.
Returns
No return value. Creates three PNG files in the output directory:

accuracy_vs_latency.png
Scatter plot of accuracy vs latency with Pareto frontier overlay.
- Blue points: Accepted configurations
- Gray X markers: Rejected configurations (exceed memory budget)
- Red line: Pareto frontier
- Resolution: 180 DPI
accuracy_vs_energy.png
Scatter plot of accuracy vs energy proxy with Pareto frontier overlay.
- Green points: Accepted configurations
- Gray X markers: Rejected configurations
- Red line: Pareto frontier
- Resolution: 180 DPI
accuracy_vs_memory.png
Scatter plot of accuracy vs memory footprint.
- Purple points: Accepted configurations
- Gray X markers: Rejected configurations
- No Pareto frontier (memory is a hard constraint)
- Resolution: 180 DPI
Implementation Details
For each plot:
- Split Data: Separate accepted and rejected configurations
- Create Figure: 7×5 inch figure with matplotlib
- Plot Points:
  - Accepted: colored circles with alpha=0.8
  - Rejected: gray X markers with alpha=0.5
- Plot Frontier: Red line connecting Pareto-optimal points (latency and energy plots only)
- Formatting: Labels, title, legend, tight layout
- Save: 180 DPI PNG file
The function automatically creates the output directory if it doesn't exist using output_dir.mkdir(parents=True, exist_ok=True).
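The steps above can be sketched for the first plot (accuracy vs latency). save_latency_plot is an illustrative helper, not the module's save_plots, and the Agg backend is used so it runs headless:

```python
from pathlib import Path
import matplotlib
matplotlib.use("Agg")  # headless backend: no display required
import matplotlib.pyplot as plt
import pandas as pd

def save_latency_plot(df: pd.DataFrame, pareto_lat: pd.DataFrame, output_dir: Path):
    output_dir.mkdir(parents=True, exist_ok=True)       # create dir as documented
    acc, rej = df[df["accepted"]], df[~df["accepted"]]  # Split Data
    fig, ax = plt.subplots(figsize=(7, 5))              # 7×5 inch figure
    ax.scatter(acc["latency_ms"], acc["accuracy"], alpha=0.8, label="accepted")
    ax.scatter(rej["latency_ms"], rej["accuracy"], marker="x", color="gray",
               alpha=0.5, label="rejected")
    ax.plot(pareto_lat["latency_ms"], pareto_lat["accuracy"], color="red",
            label="Pareto frontier")                    # frontier overlay
    ax.set_xlabel("Latency (ms)")
    ax.set_ylabel("Accuracy")
    ax.set_title("Accuracy vs Latency")
    ax.legend()
    fig.tight_layout()
    fig.savefig(output_dir / "accuracy_vs_latency.png", dpi=180)  # 180 DPI PNG
    plt.close(fig)

df = pd.DataFrame({"latency_ms": [5, 7, 10], "accuracy": [0.8, 0.85, 0.9],
                   "accepted": [True, True, False]})
save_latency_plot(df, df[df["accepted"]], Path("plots"))
```

The energy plot follows the same pattern with energy_proxy_j on the x-axis; the memory plot omits the frontier line.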