def generate_confusion_matrix(y_test, y_pred, labels, title, filename, show=False):
    """Plot a labelled confusion-matrix heatmap, save it, and optionally show it.

    Args:
        y_test: Ground-truth class labels.
        y_pred: Predicted class labels, aligned element-wise with ``y_test``.
        labels: Class labels that fix the row/column order of the matrix;
            also used as the heatmap's row and column tick labels.
        title: Title drawn above the heatmap.
        filename: Destination passed to ``fig.savefig``.
        show: If true, also display the figure with ``plt.show()``.

    The figure is always closed before returning (even if saving fails), so
    repeated calls do not accumulate open matplotlib figures.
    """
    cm = confusion_matrix(y_test, y_pred, labels=labels)
    df_cm = pd.DataFrame(cm, index=labels, columns=labels)

    # Own the figure explicitly so it can be closed afterwards; a bare
    # plt.figure() per call would leak one open figure per invocation.
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        # fmt="d": the matrix holds integer counts; seaborn's default ".2g"
        # would render counts >= 100 in scientific notation (e.g. "1e+02").
        sn.heatmap(df_cm, annot=True, fmt="d", ax=ax)
        ax.set_ylabel("Actual Label", fontsize=14, fontweight='bold')
        ax.set_xlabel("Predicted Label", fontsize=14, fontweight='bold')
        ax.set_title(title, fontsize=16, fontweight='bold')
        # Nudge the title slightly above the axes (y is in axes coordinates).
        ax.title.set_position([0.5, 1.03])
        fig.savefig(filename)
        if show:
            plt.show()
    finally:
        plt.close(fig)
评论列表
文章目录