Confusion matrix plot with one decimal but exact zeros

Question:

I’m trying to plot a confusion matrix that shows one decimal for all values, but if a value is exactly zero, I would like to display it as 0 rather than 0.0. How can I achieve this? In the minimal example below, I would like to keep 4.0 but show 0 instead of 0.0.

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_pred = [0, 0, 0, 0, 1, 1, 1, 1]

disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix(y_true, y_pred),
                              display_labels=[0, 1])

disp.plot(include_values=True, cmap='Blues', xticks_rotation='vertical',
          values_format='.1f', ax=None, colorbar=True)

plt.show()

I tried to pass a custom function

def custom_formatter(value):
    if value == 0:
        return '{:.0f}'.format(value)
    else:
        return '{:.1f}'.format(value)

to

disp.plot(include_values=True, cmap='Blues', xticks_rotation='vertical',
          values_format=custom_formatter, ax=None, colorbar=True)

but this gives the error

TypeError: format() argument 2 must be str, not function

Thank you so much!

Asked By: scriptgirl_3000

||

Answers:

I think the easiest way is to loop over the text objects once the plot has been drawn, check each label, and replace it where needed:

import numpy as np

# Run this after disp.plot(...) and before plt.show().
for iy, ix in np.ndindex(disp.text_.shape):
    txt = disp.text_[iy, ix]
    if txt.get_text() == "0.0":
        txt.set_text("0")

[Image: confusion matrix in which every 0.0 has been replaced with 0, with no decimal point.]
[Image: a second confusion matrix with 6 rows and columns.]
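
The error message in the question suggests that values_format is passed to Python's format() as a format-spec string, which would explain why a function is rejected. Comparing label strings also has a catch: with '.1f', a small non-zero value such as 0.04 (for example in a normalized matrix) is also rendered as "0.0" and would be rewritten to "0". A variant that decides from the raw values instead might look like the sketch below. The helper names zero_aware_label and relabel_cells are made up for illustration and are not part of scikit-learn. The sketch assumes that disp.text_ and disp.confusion_matrix have the same 2-D shape.

```python
def zero_aware_label(value, spec=".1f"):
    """Format value with spec, but render an exact zero as a bare '0'."""
    return "0" if value == 0 else format(value, spec)


def relabel_cells(values, texts, spec=".1f"):
    """Overwrite each cell label based on the matching raw value.

    values: 2-D array-like of raw cell values (assumed: disp.confusion_matrix)
    texts:  2-D array-like of objects with set_text() (assumed: disp.text_)
    """
    for value_row, text_row in zip(values, texts):
        for value, text in zip(value_row, text_row):
            text.set_text(zero_aware_label(value, spec))
```

To use it with the example from the question, call relabel_cells(disp.confusion_matrix, disp.text_) after disp.plot(...) and before plt.show(). Only cells whose underlying value is exactly zero lose the decimal. A value that merely rounds to zero keeps its "0.0" label.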

Answered By: Alexander L. Hayes