Evaluator

Overview

The evaluator module provides functions for running inference on test sets and computing comprehensive evaluation metrics including confusion matrices, precision, recall, F1 scores, and classification reports.

Functions

run_test_evaluation

Runs inference on the test set using a trained model checkpoint and computes comprehensive evaluation metrics.
def run_test_evaluation(
    experiment_id: str,
    model_config: dict,
    dataset_config: dict,
) -> dict
experiment_id
str
required
Unique identifier for the experiment. Used to locate the best checkpoint.
model_config
dict
required
Model configuration dictionary with keys:
  • architecture: Model architecture name
  • num_classes: Number of output classes
  • pretrained: Whether to use pretrained weights
  • Other architecture-specific parameters
dataset_config
dict
required
Dataset configuration dictionary with keys:
  • dataset_path: Path to dataset directory
  • selected_families: Optional list of malware families
  • preprocessing: Preprocessing settings
  • split: Train/val/test split configuration
Returns: Dictionary with evaluation results:
{
    "confusion_matrix": [[...], [...], ...],  # 2D array (num_classes x num_classes)
    "classification_report": {...},            # Detailed per-class metrics
    "class_names": ["class1", "class2", ...], # List of class names
    "per_class": {
        "precision": [0.95, 0.87, ...],        # Per-class precision
        "recall": [0.92, 0.89, ...],           # Per-class recall
        "f1": [0.93, 0.88, ...],               # Per-class F1 scores
        "support": [120, 95, ...]              # Samples per class
    },
    "accuracy": 0.91,                          # Overall accuracy
    "total_samples": 500                       # Total test samples
}

Workflow

The function performs the following steps:
  1. Device Setup: Automatically selects the best available device (CUDA > MPS > CPU)
  2. Checkpoint Loading: Loads the best checkpoint for the experiment using CheckpointManager
  3. Model Building: Constructs the model from config and loads trained weights
  4. DataLoader Creation: Creates test DataLoader with appropriate transforms
  5. Inference: Runs model inference on all test samples
  6. Metrics Computation: Calculates confusion matrix, classification report, and per-class metrics

Example Usage

Basic Evaluation

from training.evaluator import run_test_evaluation

# Define configurations
experiment_id = "exp_20240315_123456"

model_config = {
    "architecture": "ResNet50",
    "num_classes": 9,
    "pretrained": True,
    "freeze_backbone": False
}

dataset_config = {
    "dataset_path": "dataset",
    "selected_families": None,
    "preprocessing": {
        "target_size": (224, 224),
        "normalization": "ImageNet Mean/Std",
        "color_mode": "RGB"
    },
    "split": {
        "train": 70,
        "val": 15,
        "test": 15,
        "stratified": True,
        "random_seed": 72
    }
}

# Run evaluation
results = run_test_evaluation(
    experiment_id=experiment_id,
    model_config=model_config,
    dataset_config=dataset_config
)

print(f"Test Accuracy: {results['accuracy']*100:.2f}%")
print(f"Total Samples: {results['total_samples']}")

Accessing Results

results = run_test_evaluation(experiment_id, model_config, dataset_config)

# Overall accuracy
print(f"Accuracy: {results['accuracy']*100:.2f}%")

# Per-class metrics
class_names = results["class_names"]
per_class = results["per_class"]

print("\nPer-Class Results:")
for i, class_name in enumerate(class_names):
    print(f"{class_name}:")
    print(f"  Precision: {per_class['precision'][i]:.3f}")
    print(f"  Recall:    {per_class['recall'][i]:.3f}")
    print(f"  F1 Score:  {per_class['f1'][i]:.3f}")
    print(f"  Support:   {per_class['support'][i]}")

# Confusion matrix
import numpy as np
cm = np.array(results["confusion_matrix"])
print(f"\nConfusion Matrix Shape: {cm.shape}")

Visualizing Confusion Matrix

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

results = run_test_evaluation(experiment_id, model_config, dataset_config)

# Extract confusion matrix and class names
cm = np.array(results["confusion_matrix"])
class_names = results["class_names"]

# Normalize confusion matrix by row (true-label) totals.
# Note: this assumes every class has at least one test sample; a class with
# zero support would produce a division-by-zero (NaN row).
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Plot
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm_normalized,
    annot=True,
    fmt='.2f',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    cbar_kws={'label': 'Proportion'}
)
plt.title(f'Confusion Matrix - {experiment_id}')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.savefig(f'confusion_matrix_{experiment_id}.png', dpi=300)
plt.show()

Classification Report

import json

results = run_test_evaluation(experiment_id, model_config, dataset_config)

# Get detailed classification report
report = results["classification_report"]

print("\nDetailed Classification Report:")
print(json.dumps(report, indent=2))

# The report includes per-class and aggregate metrics:
# - precision, recall, f1-score, support for each class
# - macro avg: unweighted mean across classes
# - weighted avg: weighted by support
# - accuracy: overall accuracy

Comparing Multiple Experiments

import numpy as np

from training.evaluator import run_test_evaluation

experiment_ids = [
    "exp_resnet50",
    "exp_efficientnet",
    "exp_vit"
]

results_list = []

for exp_id in experiment_ids:
    # NOTE: this example reuses a single model_config/dataset_config for all
    # experiments; in practice, pass the configuration each experiment was
    # actually trained with.
    results = run_test_evaluation(
        experiment_id=exp_id,
        model_config=model_config,
        dataset_config=dataset_config
    )
    results_list.append({
        "experiment": exp_id,
        "accuracy": results["accuracy"],
        "avg_precision": np.mean(results["per_class"]["precision"]),
        "avg_recall": np.mean(results["per_class"]["recall"]),
        "avg_f1": np.mean(results["per_class"]["f1"])
    })

# Print comparison
import pandas as pd
df = pd.DataFrame(results_list)
print("\nExperiment Comparison:")
print(df.to_string(index=False))

# Find best model
best_exp = df.loc[df["accuracy"].idxmax()]
print(f"\nBest Model: {best_exp['experiment']} (Acc: {best_exp['accuracy']*100:.2f}%)")

Analyzing Misclassifications

import numpy as np

results = run_test_evaluation(experiment_id, model_config, dataset_config)

cm = np.array(results["confusion_matrix"])
class_names = results["class_names"]

print("\nTop Misclassification Pairs:")
misclassifications = []

for i in range(len(class_names)):
    for j in range(len(class_names)):
        if i != j and cm[i, j] > 0:
            misclassifications.append({
                "true": class_names[i],
                "predicted": class_names[j],
                "count": int(cm[i, j])
            })

# Sort by count
misclassifications.sort(key=lambda x: x["count"], reverse=True)

# Print top 10
for mc in misclassifications[:10]:
    print(f"  {mc['true']:15s} -> {mc['predicted']:15s}: {mc['count']:3d} samples")

Saving Results

import json
from pathlib import Path

results = run_test_evaluation(experiment_id, model_config, dataset_config)

# Create results directory
results_dir = Path("evaluation_results")
results_dir.mkdir(exist_ok=True)

# Save as JSON
with open(results_dir / f"{experiment_id}_results.json", "w") as f:
    json.dump(results, f, indent=2)

print(f"Results saved to {results_dir / f'{experiment_id}_results.json'}")

# Save summary as text
with open(results_dir / f"{experiment_id}_summary.txt", "w") as f:
    f.write(f"Experiment: {experiment_id}\n")
    f.write(f"Test Accuracy: {results['accuracy']*100:.2f}%\n")
    f.write(f"Total Samples: {results['total_samples']}\n\n")
    f.write("Per-Class Results:\n")
    for i, name in enumerate(results["class_names"]):
        f.write(f"\n{name}:\n")
        f.write(f"  Precision: {results['per_class']['precision'][i]:.3f}\n")
        f.write(f"  Recall:    {results['per_class']['recall'][i]:.3f}\n")
        f.write(f"  F1:        {results['per_class']['f1'][i]:.3f}\n")
        f.write(f"  Support:   {results['per_class']['support'][i]}\n")

Error Handling

The function raises a ValueError containing "No checkpoint found" when the experiment has no saved checkpoints (for example, if training never completed); other failures propagate as ordinary exceptions. Handle them like this:
try:
    results = run_test_evaluation(experiment_id, model_config, dataset_config)
except ValueError as e:
    if "No checkpoint found" in str(e):
        print(f"Error: Experiment '{experiment_id}' has no saved checkpoints")
        print("Make sure training completed successfully")
    else:
        raise
except Exception as e:
    print(f"Evaluation failed: {e}")
