The ConfusionMatrix class template stores and updates a confusion matrix for multi-class classification problems. It counts how often each actual class label is predicted as each class, and provides helpers such as trace() and total() for computing metrics like accuracy.
Template parameters
T
typename
default: std::size_t
Arithmetic type used to store the counts. Must satisfy the Arithmetic concept, i.e. be an arithmetic type such as int, std::size_t, or float.
Label
typename
default: std::size_t
Type of class labels. Can be any integer-like type that is convertible to std::size_t.
Type aliases
The arithmetic type used for matrix elements.
Constructor
explicit ConfusionMatrix(std::size_t num_classes);
Creates a confusion matrix for a given number of classes.
num_classes
std::size_t
Number of classes in the classification problem. Must be greater than 0.
Throws std::invalid_argument if num_classes is 0.
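The constructor's contract boils down to validating the argument before allocating storage. The sketch below mirrors that contract with a hypothetical free function, make_count_grid, that is not part of mlpp; it also assumes the cells start at zero, which this section does not state explicitly.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not part of mlpp): builds an n x n grid of counts
// and rejects n == 0 the same way the documented constructor does.
template <typename T = std::size_t>
std::vector<std::vector<T>> make_count_grid(std::size_t num_classes) {
    if (num_classes == 0) {
        throw std::invalid_argument("num_classes must be greater than 0");
    }
    // Assumption: every cell starts at zero.
    return std::vector<std::vector<T>>(num_classes, std::vector<T>(num_classes, T{}));
}

// Returns true if building a grid with num_classes throws std::invalid_argument.
bool rejects(std::size_t num_classes) {
    try {
        make_count_grid(num_classes);
    } catch (const std::invalid_argument&) {
        return true;
    }
    return false;
}
```

With the real class the same pattern applies: when the class count comes from user input, construct ConfusionMatrix inside a try block and catch std::invalid_argument.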
Static factory methods
fixed
template<std::size_t N>
static constexpr ConfusionMatrix fixed();
Creates a confusion matrix whose number of classes is known at compile time.
N
std::size_t
Number of classes (compile-time constant).
return
ConfusionMatrix
A new confusion matrix with N classes.
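This section does not say how fixed<N>() is implemented or whether N == 0 is rejected at compile time. One plausible design, sketched below with a hypothetical SketchMatrix class, forwards N to the runtime constructor and adds a static_assert so that a zero class count fails to compile instead of throwing. The documented constexpr qualifier is omitted here for portability.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for ConfusionMatrix, reduced to the parts
// needed to illustrate a compile-time factory.
template <typename T = std::size_t>
class SketchMatrix {
public:
    explicit SketchMatrix(std::size_t num_classes) {
        if (num_classes == 0) {
            throw std::invalid_argument("num_classes must be greater than 0");
        }
        m_.assign(num_classes, std::vector<T>(num_classes, T{}));
    }

    // Assumption: the factory forwards N to the runtime constructor and
    // rejects N == 0 before the program is built.
    template <std::size_t N>
    static SketchMatrix fixed() {
        static_assert(N > 0, "a confusion matrix needs at least one class");
        return SketchMatrix(N);
    }

    std::size_t num_classes() const noexcept { return m_.size(); }

private:
    std::vector<std::vector<T>> m_;
};
```

Under this design, SketchMatrix<int>::fixed<0>() is a compile error, whereas SketchMatrix<int>(0) only fails at run time.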
Methods
update
void update(const Label& y_true, const Label& y_pred) noexcept;
Records a single prediction, where y_true is the actual class label and y_pred is the predicted one.
Increments the cell at [y_true][y_pred] by 1. Labels outside the range [0, num_classes) are silently ignored and leave the matrix unchanged.
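The bounds-checked increment can be sketched as a free function over a nested std::vector. The name record is hypothetical and this is an illustration of the documented behavior, not the library's code. Casting the labels to std::size_t also makes negative signed labels wrap to very large values, so they fail the range check as well.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical free-function version of update(): increments
// m[y_true][y_pred] and silently skips labels outside [0, n).
template <typename T, typename Label>
void record(std::vector<std::vector<T>>& m, const Label& y_true, const Label& y_pred) noexcept {
    // A negative signed label wraps to a large unsigned value here,
    // so it is rejected by the range check below.
    const auto t = static_cast<std::size_t>(y_true);
    const auto p = static_cast<std::size_t>(y_pred);
    const std::size_t n = m.size();
    if (t >= n || p >= n) {
        return;  // invalid label: leave the matrix unchanged
    }
    m[t][p] += T{1};
}
```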
clear
Resets all matrix elements to zero.
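This section does not state whether clear() keeps the number of classes. The sketch below assumes it does, resetting only the counts, and uses a hypothetical free function reset over a nested std::vector.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical equivalent of clear(): zero every cell but keep the
// n x n shape, so the same object can be reused for another run.
template <typename T>
void reset(std::vector<std::vector<T>>& m) noexcept {
    for (auto& row : m) {
        std::fill(row.begin(), row.end(), T{});
    }
}
```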
num_classes
[[nodiscard]] std::size_t num_classes() const noexcept;
return
std::size_t
Number of classes in the matrix.
data
[[nodiscard]] const std::vector<std::vector<T>>& data() const noexcept;
return
const std::vector<std::vector<T>>&
Reference to the underlying 2D matrix data
operator[]
[[nodiscard]] const std::vector<T>& operator[](std::size_t i) const noexcept;
[[nodiscard]] std::vector<T>& operator[](std::size_t i) noexcept;
Provides row access to the matrix.
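i
std::size_t
Row index (the true class label).
return
const std::vector<T>& or std::vector<T>&
Reference to row i of the matrix; the const overload returns a read-only row, the non-const overload a mutable one.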
Matrix layout: M[true_class][predicted_class]
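Because rows index the true class and columns the predicted class, a row sum counts the samples that truly belong to a class and a column sum counts the samples predicted as that class. The hypothetical helpers below (not part of mlpp) use this to compute per-class recall and precision; returning 0.0 when the denominator is zero is a convention chosen for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helpers, not part of mlpp. With layout M[true][predicted]:
//   sum of row k    = samples whose true class is k
//   sum of column k = samples predicted as class k
template <typename T>
double recall(const std::vector<std::vector<T>>& m, std::size_t k) {
    double row_sum = 0.0;
    for (const T& v : m[k]) row_sum += static_cast<double>(v);
    return row_sum == 0.0 ? 0.0 : static_cast<double>(m[k][k]) / row_sum;
}

template <typename T>
double precision(const std::vector<std::vector<T>>& m, std::size_t k) {
    double col_sum = 0.0;
    for (const auto& row : m) col_sum += static_cast<double>(row[k]);
    return col_sum == 0.0 ? 0.0 : static_cast<double>(m[k][k]) / col_sum;
}
```

Since data() returns the underlying const std::vector<std::vector<T>>&, helpers like these could be applied to a real matrix as recall(cm.data(), k).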
trace
[[nodiscard]] T trace() const noexcept;
Computes the sum of the diagonal elements, i.e. the total number of correct predictions.
return
T
Sum of all correctly classified samples.
total
[[nodiscard]] T total() const noexcept;
Computes the sum of all matrix elements, i.e. the total number of recorded samples.
return
T
Total count of all samples.
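Accuracy is trace() divided by total(). If no samples have been recorded, total() is zero and the division has no meaningful result. The sketch below reimplements both sums as hypothetical free functions over a nested std::vector and guards the empty case by returning 0.0, a convention chosen for this illustration rather than documented library behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical equivalent of trace(): sum of m[i][i].
template <typename T>
T diagonal_sum(const std::vector<std::vector<T>>& m) noexcept {
    T s{};
    for (std::size_t i = 0; i < m.size(); ++i) s += m[i][i];
    return s;
}

// Hypothetical equivalent of total(): sum of every cell.
template <typename T>
T grand_total(const std::vector<std::vector<T>>& m) noexcept {
    T s{};
    for (const auto& row : m)
        for (const T& v : row) s += v;
    return s;
}

// Accuracy = trace / total; returns 0.0 for a matrix with no samples
// instead of dividing by zero.
template <typename T>
double accuracy(const std::vector<std::vector<T>>& m) noexcept {
    const T n = grand_total(m);
    return n == T{} ? 0.0
                    : static_cast<double>(diagonal_sum(m)) / static_cast<double>(n);
}
```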
print
void print(std::ostream& os = std::cout, int width = 8) const;
Prints the confusion matrix to an output stream.
os
std::ostream&
default: std::cout
Output stream to write to.
width
int
default: 8
Column width used for formatting.
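The exact output format of print() is not specified here. As a rough sketch of how a width parameter is typically applied, the hypothetical print_grid below writes one line per true class and right-aligns each cell with std::setw(width). The real print() may differ, for example by adding row or column headers.

```cpp
#include <cassert>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical printer: one line per true class, each cell right-aligned
// in a field of `width` characters. Not the library's actual format.
template <typename T>
void print_grid(const std::vector<std::vector<T>>& m,
                std::ostream& os = std::cout, int width = 8) {
    for (const auto& row : m) {
        for (const T& v : row) {
            os << std::setw(width) << v;
        }
        os << '\n';
    }
}
```

Because os is a std::ostream&, the documented print() can likewise target a std::ostringstream or std::ofstream instead of std::cout.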
Example usage
#include <mlpp/model_validation/confusion_matrix.hpp>
#include <cstddef>
#include <iostream>
#include <vector>

using namespace mlpp::model_validation;

int main() {
    // Create a 3-class confusion matrix
    ConfusionMatrix<std::size_t> cm(3);
    // Simulate predictions
    std::vector<std::size_t> y_true = {0, 0, 1, 1, 2, 2, 0, 1, 2};
    std::vector<std::size_t> y_pred = {0, 1, 1, 1, 2, 0, 0, 2, 2};
    // Update the matrix with each (true, predicted) pair
    for (std::size_t i = 0; i < y_true.size(); ++i) {
        cm.update(y_true[i], y_pred[i]);
    }
    // Display the matrix
    cm.print();
    // Access individual elements: row = true class, column = predicted class
    std::cout << "True class 0 predicted as class 1: "
              << cm[0][1] << std::endl;
    // Compute accuracy: correct predictions / all predictions (6 / 9 here)
    double accuracy = static_cast<double>(cm.trace()) / cm.total();
    std::cout << "Accuracy: " << accuracy << std::endl;
    // Clear for reuse
    cm.clear();
}
Compile-time matrix
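#include <mlpp/model_validation/confusion_matrix.hpp>
using namespace mlpp::model_validation;
int main() {
    // Create a matrix whose class count (4) is fixed at compile time
    auto cm = ConfusionMatrix<int>::fixed<4>();
    // Use it like any other ConfusionMatrix
    cm.update(0, 0);
    cm.update(1, 2);
}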