def plot_confusion_matrix(y_ground, y_pred, title='Normalized confusion matrix',
                          cmap=plt.cm.Blues, labels=None):
    """Plot a row-normalized confusion matrix and display it.

    Each row of the confusion matrix is divided by its row total, so a
    cell holds the fraction of samples of that row's true class that were
    predicted as the column's class.

    Args:
        y_ground: Ground-truth labels, forwarded to ``confusion_matrix``.
        y_pred: Predicted labels, forwarded to ``confusion_matrix``.
        title: Figure title.
        cmap: Colormap passed to ``plt.imshow``.
        labels: Optional sequence of class labels. When given, it fixes the
            row/column order of the matrix (forwarded to
            ``confusion_matrix``) and is drawn as the axis tick labels.
            When ``None`` (the default), ``confusion_matrix`` chooses the
            label order and no tick labels are drawn, as before.

    Returns:
        The row-normalized confusion matrix as a float array. A row whose
        total is zero is all zeros rather than NaN.
    """
    print('Plotting confusion matrix..')
    cm = confusion_matrix(y_ground, y_pred, labels=labels)
    # Normalize by row, i.e. by the number of samples in each true class.
    # A label that appears only in y_pred gets a row summing to 0; divide
    # only where the row sum is non-zero so that row becomes 0 instead of
    # NaN (with a divide-by-zero RuntimeWarning).
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm.astype('float'), row_sums,
                              out=np.zeros(cm.shape, dtype='float'),
                              where=row_sums != 0)
    print('Normalized confusion matrix')
    plt.figure()
    plt.imshow(cm_normalized, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    if labels is not None:
        tick_marks = np.arange(len(labels))
        plt.xticks(tick_marks, labels, rotation=45)
        plt.yticks(tick_marks, labels)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out only after the axis labels/ticks exist so they are not clipped.
    plt.tight_layout()
    plt.show()
    return cm_normalized
评论列表
文章目录