def display_confusion_matrix(test_data, test_labels, save=False,
                             path="../results/mnist/confusion_matrix.png"):
    """
    Plot a matrix representing the choices made by the network
    on a testing batch.

    The X axis holds the predicted labels, the Y axis the expected labels.

    :param test_data: inputs passed to ``mnist.predict``.
    :param test_labels: expected labels for ``test_data``.
    :param save: if truthy, also write the figure to ``path`` as an image.
    :param path: destination file used when ``save`` is set; defaults to
        the location previously hard-coded here.
    """
    expected = test_labels
    predicted = mnist.predict(test_data)
    cm = confusion_matrix(expected, predicted)
    plt.matshow(cm)
    plt.title('Confusion matrix')
    plt.colorbar()
    plt.ylabel('Expected label')
    plt.xlabel('Predicted label')
    # Save *before* show(): with blocking/interactive backends show() closes
    # or releases the figure, and a later savefig() would write a blank image.
    if save:
        plt.savefig(path)
    plt.show()
评论列表
文章目录