Overview
The evaluator module provides functions for running inference on a test set and computing evaluation metrics, including confusion matrices, precision, recall, F1 scores, and classification reports.
Functions
run_test_evaluation
Runs inference on the test set using a trained model checkpoint and computes comprehensive evaluation metrics.
def run_test_evaluation(
experiment_id: str,
model_config: dict,
dataset_config: dict,
) -> dict
Parameters:
- experiment_id: Unique identifier for the experiment. Used to locate the best checkpoint.
- model_config: Model configuration dictionary with keys:
  - architecture: Model architecture name
  - num_classes: Number of output classes
  - pretrained: Whether to use pretrained weights
  - Other architecture-specific parameters
- dataset_config: Dataset configuration dictionary with keys:
  - dataset_path: Path to dataset directory
  - selected_families: Optional list of malware families
  - preprocessing: Preprocessing settings
  - split: Train/val/test split configuration
Returns: Dictionary with evaluation results:
{
"confusion_matrix": [[...], [...], ...], # 2D array (num_classes x num_classes)
"classification_report": {...}, # Detailed per-class metrics
"class_names": ["class1", "class2", ...], # List of class names
"per_class": {
"precision": [0.95, 0.87, ...], # Per-class precision
"recall": [0.92, 0.89, ...], # Per-class recall
"f1": [0.93, 0.88, ...], # Per-class F1 scores
"support": [120, 95, ...] # Samples per class
},
"accuracy": 0.91, # Overall accuracy
"total_samples": 500 # Total test samples
}
Workflow
The function performs the following steps:
- Device Setup: Automatically selects the best available device (CUDA > MPS > CPU)
- Checkpoint Loading: Loads the best checkpoint for the experiment using
CheckpointManager
- Model Building: Constructs the model from config and loads trained weights
- DataLoader Creation: Creates test DataLoader with appropriate transforms
- Inference: Runs model inference on all test samples
- Metrics Computation: Calculates confusion matrix, classification report, and per-class metrics
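For orientation, the core of this flow can be expressed in a few lines of standard PyTorch and scikit-learn. The sketch below only illustrates the device-selection, inference, and metrics steps; it is not the module's actual implementation (the real function also handles checkpoint loading via CheckpointManager, model building, and DataLoader creation):

import torch
from sklearn.metrics import classification_report, confusion_matrix

def evaluate_sketch(model: torch.nn.Module, test_loader) -> dict:
    # Step 1: pick the best available device (CUDA > MPS > CPU)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model.to(device).eval()
    all_preds, all_labels = [], []

    # Step 5: run inference over the whole test set
    with torch.no_grad():
        for images, labels in test_loader:
            logits = model(images.to(device))
            all_preds.extend(logits.argmax(dim=1).cpu().tolist())
            all_labels.extend(labels.tolist())

    # Step 6: compute confusion matrix, classification report, and accuracy
    return {
        "confusion_matrix": confusion_matrix(all_labels, all_preds).tolist(),
        "classification_report": classification_report(
            all_labels, all_preds, output_dict=True, zero_division=0
        ),
        "accuracy": sum(p == t for p, t in zip(all_preds, all_labels)) / len(all_labels),
        "total_samples": len(all_labels),
    }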
Example Usage
Basic Evaluation
from training.evaluator import run_test_evaluation
# Define configurations
experiment_id = "exp_20240315_123456"
model_config = {
"architecture": "ResNet50",
"num_classes": 9,
"pretrained": True,
"freeze_backbone": False
}
dataset_config = {
"dataset_path": "dataset",
"selected_families": None,
"preprocessing": {
"target_size": (224, 224),
"normalization": "ImageNet Mean/Std",
"color_mode": "RGB"
},
"split": {
"train": 70,
"val": 15,
"test": 15,
"stratified": True,
"random_seed": 72
}
}
# Run evaluation
results = run_test_evaluation(
experiment_id=experiment_id,
model_config=model_config,
dataset_config=dataset_config
)
print(f"Test Accuracy: {results['accuracy']*100:.2f}%")
print(f"Total Samples: {results['total_samples']}")
Accessing Results
results = run_test_evaluation(experiment_id, model_config, dataset_config)
# Overall accuracy
print(f"Accuracy: {results['accuracy']*100:.2f}%")
# Per-class metrics
class_names = results["class_names"]
per_class = results["per_class"]
print("\nPer-Class Results:")
for i, class_name in enumerate(class_names):
print(f"{class_name}:")
print(f" Precision: {per_class['precision'][i]:.3f}")
print(f" Recall: {per_class['recall'][i]:.3f}")
print(f" F1 Score: {per_class['f1'][i]:.3f}")
print(f" Support: {per_class['support'][i]}")
# Confusion matrix
import numpy as np
cm = np.array(results["confusion_matrix"])
print(f"\nConfusion Matrix Shape: {cm.shape}")
Visualizing Confusion Matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
results = run_test_evaluation(experiment_id, model_config, dataset_config)
# Extract confusion matrix and class names
cm = np.array(results["confusion_matrix"])
class_names = results["class_names"]
# Normalize confusion matrix
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
# Plot
plt.figure(figsize=(10, 8))
sns.heatmap(
cm_normalized,
annot=True,
fmt='.2f',
cmap='Blues',
xticklabels=class_names,
yticklabels=class_names,
cbar_kws={'label': 'Proportion'}
)
plt.title(f'Confusion Matrix - {experiment_id}')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.savefig(f'confusion_matrix_{experiment_id}.png', dpi=300)
plt.show()
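Because each row is normalized by the number of true samples in that class, the diagonal of the normalized matrix corresponds to per-class recall.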
Classification Report
import json
results = run_test_evaluation(experiment_id, model_config, dataset_config)
# Get detailed classification report
report = results["classification_report"]
print("\nDetailed Classification Report:")
print(json.dumps(report, indent=2))
# The report includes per-class and aggregate metrics:
# - precision, recall, f1-score, support for each class
# - macro avg: unweighted mean across classes
# - weighted avg: weighted by support
# - accuracy: overall accuracy
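Assuming the report follows the usual scikit-learn classification_report(output_dict=True) layout (which the aggregate keys listed above suggest), the macro and weighted averages can be read directly:
macro_f1 = report["macro avg"]["f1-score"]
weighted_f1 = report["weighted avg"]["f1-score"]
print(f"Macro F1: {macro_f1:.3f} | Weighted F1: {weighted_f1:.3f}")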
Comparing Multiple Experiments
import numpy as np  # needed for np.mean below
from training.evaluator import run_test_evaluation
experiment_ids = [
"exp_resnet50",
"exp_efficientnet",
"exp_vit"
]
results_list = []
for exp_id in experiment_ids:
results = run_test_evaluation(
experiment_id=exp_id,
model_config=model_config,
dataset_config=dataset_config
)
results_list.append({
"experiment": exp_id,
"accuracy": results["accuracy"],
"avg_precision": np.mean(results["per_class"]["precision"]),
"avg_recall": np.mean(results["per_class"]["recall"]),
"avg_f1": np.mean(results["per_class"]["f1"])
})
# Print comparison
import pandas as pd
df = pd.DataFrame(results_list)
print("\nExperiment Comparison:")
print(df.to_string(index=False))
# Find best model
best_exp = df.loc[df["accuracy"].idxmax()]
print(f"\nBest Model: {best_exp['experiment']} (Acc: {best_exp['accuracy']*100:.2f}%)")
Analyzing Misclassifications
import numpy as np
results = run_test_evaluation(experiment_id, model_config, dataset_config)
cm = np.array(results["confusion_matrix"])
class_names = results["class_names"]
print("\nTop Misclassification Pairs:")
misclassifications = []
for i in range(len(class_names)):
for j in range(len(class_names)):
if i != j and cm[i, j] > 0:
misclassifications.append({
"true": class_names[i],
"predicted": class_names[j],
"count": int(cm[i, j])
})
# Sort by count
misclassifications.sort(key=lambda x: x["count"], reverse=True)
# Print top 10
for mc in misclassifications[:10]:
print(f" {mc['true']:15s} → {mc['predicted']:15s}: {mc['count']:3d} samples")
Saving Results
import json
from pathlib import Path
results = run_test_evaluation(experiment_id, model_config, dataset_config)
# Create results directory
results_dir = Path("evaluation_results")
results_dir.mkdir(exist_ok=True)
# Save as JSON
with open(results_dir / f"{experiment_id}_results.json", "w") as f:
json.dump(results, f, indent=2)
print(f"Results saved to {results_dir / f'{experiment_id}_results.json'}")
# Save summary as text
with open(results_dir / f"{experiment_id}_summary.txt", "w") as f:
f.write(f"Experiment: {experiment_id}\n")
f.write(f"Test Accuracy: {results['accuracy']*100:.2f}%\n")
f.write(f"Total Samples: {results['total_samples']}\n\n")
f.write("Per-Class Results:\n")
for i, name in enumerate(results["class_names"]):
f.write(f"\n{name}:\n")
f.write(f" Precision: {results['per_class']['precision'][i]:.3f}\n")
f.write(f" Recall: {results['per_class']['recall'][i]:.3f}\n")
f.write(f" F1: {results['per_class']['f1'][i]:.3f}\n")
f.write(f" Support: {results['per_class']['support'][i]}\n")
Error Handling
The most common failure mode is a missing checkpoint: if no checkpoint exists for the given experiment, the function raises a ValueError whose message contains "No checkpoint found". Handle that case explicitly and let unexpected errors propagate:
try:
results = run_test_evaluation(experiment_id, model_config, dataset_config)
except ValueError as e:
if "No checkpoint found" in str(e):
print(f"Error: Experiment '{experiment_id}' has no saved checkpoints")
print("Make sure training completed successfully")
else:
raise
except Exception as e:
print(f"Evaluation failed: {e}")