import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
def make_confusion_matrix(cf, group_names=None, categories='auto', count=True,
                          percent=True, cbar=True, xyticks=True, xyplotlabels=True,
                          sum_stats=True, figsize=None, cmap='Blues', title=None):
    """
    Generates a labeled confusion matrix using Seaborn's heatmap.

    Args:
    - cf (array-like): 2D array representing the confusion matrix.
    - group_names (list): List of labels to use for annotating the matrix
      (one per cell, row-major order).
    - categories (list or 'auto'): List of categories for axis labels.
    - count (bool): Whether to display counts in matrix boxes.
    - percent (bool): Whether to display percentage in matrix boxes.
    - cbar (bool): Whether to display a colorbar.
    - xyticks (bool): Whether to display x and y tick labels.
    - xyplotlabels (bool): Whether to display x and y labels.
    - sum_stats (bool): Whether to display summary statistics.
    - figsize (tuple): Size of the figure.
    - cmap (str): Colormap to be used.
    - title (str): Title of the heatmap.

    Returns:
    - None: Displays a heatmap.

    Raises:
    - ValueError: If cf is not a square 2D matrix, or if group_names is
      provided and its length does not match the number of cells.
    """
    # Coerce to ndarray so any array-like (e.g. list of lists) is accepted,
    # as the docstring promises; previously a plain list crashed on cf.shape.
    cf = np.asarray(cf)
    # Ensure matrix is 2D and square
    if cf.ndim != 2 or cf.shape[0] != cf.shape[1]:
        raise ValueError("Confusion matrix should be a square matrix.")
    # Ensure group_names match matrix size if provided (one label per cell)
    if group_names and len(group_names) != cf.size:
        raise ValueError("Length of group_names should match the size of the confusion matrix.")

    total = np.sum(cf)

    # Annotations for each box (row-major, later reshaped to cf.shape)
    blanks = ['' for _ in range(cf.size)]
    group_labels = ["{}\n".format(value) for value in group_names] if group_names else blanks
    group_counts = ["{0:0.0f}\n".format(value) for value in cf.flatten()] if count else blanks
    # Guard total == 0 to avoid division-by-zero when the matrix is all zeros.
    if percent and total > 0:
        group_percentages = ["{0:.2%}".format(value) for value in cf.flatten() / total]
    else:
        group_percentages = blanks
    box_labels = [f"{v1}{v2}{v3}".strip() for v1, v2, v3 in zip(group_labels, group_counts, group_percentages)]
    box_labels = np.asarray(box_labels).reshape(cf.shape[0], cf.shape[1])

    # Summary statistics shown under the x-axis label
    stats_text = ""
    if sum_stats and total > 0:
        accuracy = np.trace(cf) / float(total)
        if len(cf) == 2:  # Binary classification metrics
            # Guard degenerate rows/columns (no predicted/actual positives)
            # so we report 0.0 instead of nan/inf with runtime warnings.
            pred_pos = cf[:, 1].sum()
            actual_pos = cf[1, :].sum()
            precision = cf[1, 1] / pred_pos if pred_pos else 0.0
            recall = cf[1, 1] / actual_pos if actual_pos else 0.0
            pr_sum = precision + recall
            f1_score = 2 * precision * recall / pr_sum if pr_sum else 0.0
            stats_text = f"\n\nAccuracy={accuracy:.3f}\nPrecision={precision:.3f}\nRecall={recall:.3f}\nF1 Score={f1_score:.3f}"
        else:
            stats_text = f"\n\nAccuracy={accuracy:.3f}"

    # Honor the xyticks flag (previously documented but silently ignored):
    # passing False to seaborn's tick label arguments suppresses the labels.
    if not xyticks:
        categories = False

    # Visualization
    if not figsize:
        figsize = plt.rcParams.get('figure.figsize')
    plt.figure(figsize=figsize)
    sns.heatmap(cf, annot=box_labels, fmt="", cmap=cmap, cbar=cbar, xticklabels=categories, yticklabels=categories)
    if xyplotlabels:
        plt.ylabel('True label')
        plt.xlabel('Predicted label' + stats_text)
    else:
        plt.xlabel(stats_text)
    if title:
        plt.title(title)
    plt.show()