tools.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:AutoSleepScorerDev 作者: skjerns 项目源码 文件源码
def plot_confusion_matrix(fname, conf_mat, target_names, 
                          title='', cmap='Blues', perc=True,figsize=[6,5],cbar=True):
    """Plot Confusion Matrix."""
    figsize = deepcopy(figsize)
    if cbar == False:
        figsize[0] = figsize[0] - 0.6
    c_names = []
    r_names = []
    if len(target_names) != len(conf_mat):
        target_names = [str(i) for  i in np.arange(len(conf_mat))]
    for i, label in enumerate(target_names):
        c_names.append(label + '\n(' + str(int(np.sum(conf_mat[:,i]))) + ')')
        align = len(str(int(np.sum(conf_mat[i,:])))) + 3 - len(label)
        r_names.append('{:{align}}'.format(label, align=align) + '\n(' + str(int(np.sum(conf_mat[i,:]))) + ')')

    cm = conf_mat
    cm = 100* cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    df = pd.DataFrame(data=np.sqrt(cm), columns=c_names, index=r_names)
    if fname != '':plt.figure(figsize=figsize)
    g  = sns.heatmap(df, annot = cm if perc else conf_mat , fmt=".1f" if perc else ".0f",
                     linewidths=.5, vmin=0, vmax=np.sqrt(100), cmap=cmap, cbar=cbar,annot_kws={"size": 13})    
    g.set_title(title)
    if cbar:
        cbar = g.collections[0].colorbar
        cbar.set_ticks(np.sqrt(np.arange(0,100,20)))
        cbar.set_ticklabels(np.arange(0,100,20))
    g.set_ylabel('True sleep stage',fontdict={'fontsize' : 12, 'fontweight':'bold'})
    g.set_xlabel('Predicted sleep stage',fontdict={'fontsize' : 12, 'fontweight':'bold'})
#    plt.tight_layout()
    if fname!='':
        plt.tight_layout()
        g.figure.savefig(os.path.join('plots', fname))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号