def _confusion_tick_labels(conf_mat, target_names):
    """Return ``(column_labels, row_labels)`` for the confusion-matrix heatmap.

    Every label is two lines: the class name, then the class total in
    parentheses. Column labels use the column sum of ``conf_mat``; row labels
    use the row sum and pad the name with trailing spaces.
    """
    col_labels = []
    row_labels = []
    for i, name in enumerate(target_names):
        name = str(name)  # accept non-string class labels (e.g. ints)
        col_total = int(np.sum(conf_mat[:, i]))
        row_total = int(np.sum(conf_mat[i, :]))
        col_labels.append(f'{name}\n({col_total})')
        # Pad the name to a width of len(digits of row total) + 3 - len(name).
        # NOTE(review): presumably meant to line the name up with the count
        # line under right-aligned y tick labels; confirm intent.
        # str.ljust is used instead of a '{:{width}}' format spec because the
        # format spec raises ValueError when the width is negative (long names
        # with small totals); ljust simply adds no padding in that case.
        width = len(str(row_total)) + 3 - len(name)
        row_labels.append(f'{name.ljust(width)}\n({row_total})')
    return col_labels, row_labels


def plot_confusion_matrix(fname, conf_mat, target_names,
                          title='', cmap='Blues', perc=True, figsize=(6, 5),
                          cbar=True):
    """Plot a confusion matrix as an annotated seaborn heatmap.

    Cell colours encode the row-normalised percentage on a square-root scale
    (which stretches out small percentages). Cell annotations show either that
    percentage or the raw counts.

    Args:
        fname: File name to save the figure under, inside the ``plots``
            directory (created if missing). If empty or ``None``, no new
            figure is created, the heatmap is drawn on the current axes and
            nothing is saved.
        conf_mat: Square array-like of counts. Rows are the true classes and
            columns the predicted classes, as labelled on the axes.
        target_names: One class name per row. Replaced by ``'0'``, ``'1'``,
            ... when its length does not match ``len(conf_mat)``.
        title: Axes title.
        cmap: Colormap passed to ``sns.heatmap``.
        perc: Annotate cells with row percentages (``True``) or raw counts.
        figsize: ``(width, height)`` of the new figure. Only used when
            ``fname`` is set. The width is reduced by 0.6 when the colorbar
            is disabled.
        cbar: Whether to draw the colorbar.
    """
    conf_mat = np.asarray(conf_mat)
    figsize = list(figsize)  # private copy: never mutate the caller's object
    if not cbar:
        # Presumably compensates for the space the colorbar would have taken.
        figsize[0] -= 0.6
    if len(target_names) != len(conf_mat):
        target_names = [str(i) for i in range(len(conf_mat))]

    col_labels, row_labels = _confusion_tick_labels(conf_mat, target_names)

    # Row-normalised percentages. A class with no true samples gives 0/0 = NaN
    # for its whole row; the warning is silenced and seaborn masks those cells.
    with np.errstate(divide='ignore', invalid='ignore'):
        percentages = (100 * conf_mat.astype(float)
                       / conf_mat.sum(axis=1, keepdims=True))

    # Colour by sqrt(percentage); vmax=sqrt(100) pins the top of the scale to 100 %.
    df = pd.DataFrame(data=np.sqrt(percentages),
                      columns=col_labels, index=row_labels)

    if fname:
        plt.figure(figsize=figsize)
    ax = sns.heatmap(df,
                     annot=percentages if perc else conf_mat,
                     fmt='.1f' if perc else '.0f',
                     linewidths=.5, vmin=0, vmax=np.sqrt(100), cmap=cmap,
                     cbar=cbar, annot_kws={'size': 13})
    ax.set_title(title)
    if cbar:
        # The colour scale is sqrt(percent): place ticks at the sqrt positions
        # but label them with plain percentages, up to and including 100.
        tick_values = np.arange(0, 101, 20)
        colorbar = ax.collections[0].colorbar
        colorbar.set_ticks(np.sqrt(tick_values))
        colorbar.set_ticklabels(tick_values)
    label_font = {'fontsize': 12, 'fontweight': 'bold'}
    ax.set_ylabel('True sleep stage', fontdict=label_font)
    ax.set_xlabel('Predicted sleep stage', fontdict=label_font)

    if fname:
        save_path = os.path.join('plots', fname)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        plt.tight_layout()
        ax.figure.savefig(save_path)
评论列表
文章目录