Source code for eobox.ml.plot

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def _safe_percent(numerator, denominator):
    """Return ``100 * numerator / denominator`` element-wise.

    Where the denominator is 0 (e.g. a class that never occurs or is never
    predicted) the result is 0 instead of NaN, so the value can safely be
    cast to an integer dtype afterwards.
    """
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        percent = 100.0 * numerator / denominator
    return np.where(denominator == 0, 0.0, percent)


def plot_confusion_matrix(cm, class_names=None, switch_axes=False, vmin=None,
                          vmax=None, cmap='RdBu', robust=True,
                          logscale_color=False, fmt=",", annot_kws=None,
                          cbar=False, mask_zeros=False, ax=None):
    """Plot a confusion matrix with per-class recall, precision and accuracy.

    The matrix is extended by one column and one row:

    * the extra column ("Rec.") holds the recall of each actual class
      (diagonal / row sum, in percent),
    * the extra row ("Prec.") holds the precision of each predicted class
      (diagonal / column sum, in percent),
    * the bottom-right cell holds the overall accuracy (trace / total,
      in percent).

    Percentages are rounded to whole numbers and cast to ``cm``'s dtype;
    classes with a zero row or column sum get 0.

    For colouring, the percentages are rescaled onto the range of the
    diagonal (``max(diagonal) * percent / 100``) and the off-diagonal counts
    are negated. With the default diverging ``'RdBu'`` colormap centred on 0,
    misclassifications are therefore drawn red, while correct
    classifications and the summary row/column are drawn blue. The
    annotations always show the raw counts and percentages.

    Parameters
    ----------
    cm : array-like of shape (n_classes, n_classes)
        Confusion matrix in the layout of
        ``sklearn.metrics.confusion_matrix``: rows are the actual classes,
        columns the predicted classes.
    class_names : sequence, optional
        Tick labels for the classes; defaults to ``1 .. n_classes``.
    switch_axes : bool, default False
        If False the matrix is shown as given (actual classes on the y-axis,
        predicted classes on the x-axis). If True it is transposed, so the
        actual classes are on the x-axis.
    vmin, vmax, cmap, robust, fmt, annot_kws, cbar
        Passed through to :func:`seaborn.heatmap`.
    logscale_color : bool, default False
        Colour by the natural log of the (rescaled) values instead of the
        values themselves. Zeros stay 0; note that values of 1 also map to 0.
    mask_zeros : bool, default False
        Hide cells whose colour value is 0.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into; defaults to the current axes.

    Returns
    -------
    matplotlib.axes.Axes
        The axes containing the heatmap.
    """
    cm = np.asarray(cm)
    ax = ax or plt.gca()
    n_classes = cm.shape[0]
    if class_names is None:
        class_names = range(1, n_classes + 1)
    # Build string labels explicitly instead of relying on numpy promoting
    # mixed int/str arrays in np.concatenate.
    class_labels = [str(name) for name in class_names]

    diagonal = cm.diagonal()
    # Rows are actual classes and columns predicted classes, so the row-wise
    # ratio is the recall and the column-wise ratio the precision.
    recall = np.round(_safe_percent(diagonal, cm.sum(axis=1))).astype(cm.dtype)
    precision = np.round(
        _safe_percent(diagonal, cm.sum(axis=0))).astype(cm.dtype)
    overall_accuracy = np.round(
        _safe_percent(diagonal.sum(), cm.sum())).astype(cm.dtype)

    # Annotation matrix: counts, recall column, precision row, accuracy corner.
    size = n_classes + 1
    cm_annot = np.zeros((size, size), dtype=cm.dtype)
    cm_annot[:-1, :-1] = cm
    cm_annot[:-1, -1] = recall
    cm_annot[-1, :-1] = precision
    cm_annot[-1, -1] = overall_accuracy

    # Colour matrix: rescale the percentages (including the accuracy corner)
    # onto the range of the diagonal counts so that they share the colour
    # scale with the diagonal. The rescaling is done before the optional log
    # transform so that both end up in the same (log) range.
    cm_color = cm_annot.astype(float)
    max_diagonal = diagonal.max()
    cm_color[:, -1] = np.round(max_diagonal * cm_color[:, -1] / 100)
    cm_color[-1, :-1] = np.round(max_diagonal * cm_color[-1, :-1] / 100)
    if logscale_color:
        with np.errstate(divide='ignore'):
            cm_color = np.log(cm_color)
        cm_color[np.isneginf(cm_color)] = 0

    # Negate only the misclassification cells (off-diagonal, outside the
    # summary row/column) so they fall on the other side of center=0.
    off_diagonal = ~np.eye(size, dtype=bool)
    off_diagonal[-1, :] = False
    off_diagonal[:, -1] = False
    cm_color[off_diagonal] *= -1

    # The extra column summarises rows (recall); the extra row summarises
    # columns (precision).
    xticklabels = class_labels + ['Rec.']
    yticklabels = class_labels + ['Prec.']
    xlabel, ylabel = 'Predicted', 'Actual'
    if switch_axes:
        cm_color = cm_color.transpose()
        cm_annot = cm_annot.transpose()
        xticklabels, yticklabels = yticklabels, xticklabels
        xlabel, ylabel = ylabel, xlabel

    mask = cm_color == 0 if mask_zeros else None
    ax = sns.heatmap(cm_color, vmin=vmin, vmax=vmax, cmap=cmap, center=0,
                     fmt=fmt, cbar=cbar, robust=robust,
                     xticklabels=xticklabels, yticklabels=yticklabels,
                     annot=cm_annot, annot_kws=annot_kws, square=True,
                     mask=mask, ax=ax)
    # Pin the y-limits to the cell grid (inverted, as in seaborn heatmaps).
    # The former "+/-0.5" shift only fixed matplotlib 3.1.1 (first/last row
    # cut in half) and added blank half-rows on every other version; see
    # https://stackoverflow.com/questions/56942670/matplotlib-seaborn-first-and-last-row-cut-in-half-of-heatmap-plot
    ax.set_ylim(cm_color.shape[0], 0)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)
    return ax