tools.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:AutoSleepScorerDev 作者: skjerns 项目源码 文件源码
def plot_difference_matrix(fname, confmat1, confmat2, target_names,
                           title='', cmap='Blues', perc=True, figsize=[5, 4], cbar=True,
                           **kwargs):
    """Plot and save the difference between two row-normalised confusion matrices.

    Both matrices are converted to per-row percentages (each row divided by
    its row sum, times 100), and ``confmat1`` is subtracted from ``confmat2``.
    Every cell of the difference is written as a text annotation. Only the
    diagonal is colour-coded: off-diagonal cells are drawn with value 0, on a
    ``coolwarm_r`` scale from -10 to +10 by default.

    Parameters
    ----------
    fname : str
        File name of the saved figure. It is written inside the ``plots``
        directory, which is created if it does not exist.
    confmat1, confmat2 : array-like, shape (n_classes, n_classes)
        Confusion matrices to compare; the plot shows ``confmat2 - confmat1``.
        Rows are labelled as the true class and columns as the predicted class.
    target_names : sequence of str
        Class labels used for both the rows and the columns.
    title : str
        Axes title.
    cmap : str
        Currently ignored; the colour map is always ``'coolwarm_r'``.
        Kept for interface compatibility.
    perc : bool
        Currently ignored; values are always percentages.
        Kept for interface compatibility.
    figsize : sequence of two floats
        ``(width, height)`` of the figure in inches. It is never modified in
        place. The width is reduced by 0.6 when ``cbar`` is false.
    cbar : bool
        Whether to draw the colour bar.
    **kwargs
        Extra keyword arguments forwarded to ``seaborn.heatmap``. They
        override this function's defaults (``annot``, ``fmt``,
        ``linewidths``, ``vmin``, ``vmax``, ``annot_kws``).

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heatmap was drawn on.

    Notes
    -----
    A row whose sum is zero (a class without any samples) yields NaN values
    in that row instead of raising numpy division warnings.
    """
    # Copy into a fresh list so neither the shared default nor the caller's
    # sequence is mutated; this also accepts tuples.
    figsize = list(figsize)
    if not cbar:
        # Shrink the figure to compensate for the missing colour bar.
        figsize[0] -= 0.6

    def _row_percent(confmat):
        """Return *confmat* as floats with every row scaled to sum to 100."""
        confmat = np.asarray(confmat, dtype=float)
        # Rows without samples become NaN; suppress numpy's 0/0 warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            return 100 * confmat / confmat.sum(axis=1)[:, np.newaxis]

    cm = _row_percent(confmat2) - _row_percent(confmat1)

    # Colour only the diagonal, i.e. the change in the percentage of each true
    # class that was predicted correctly; off-diagonal cells stay at 0 so they
    # render neutral, while the annotations below still show every value.
    cm_eye = np.zeros_like(cm)
    np.fill_diagonal(cm_eye, cm.diagonal())
    df = pd.DataFrame(data=cm_eye, columns=target_names, index=target_names)

    # Caller-supplied kwargs take precedence over these defaults. Passing them
    # side by side would raise "multiple values for keyword argument".
    heatmap_kwargs = dict(annot=cm, fmt='.1f', linewidths=.5,
                          vmin=-10, vmax=10, annot_kws={'size': 13})
    heatmap_kwargs.update(kwargs)

    plt.figure(figsize=figsize)
    ax = sns.heatmap(df, cmap='coolwarm_r', cbar=cbar, **heatmap_kwargs)

    label_font = {'fontsize': 12, 'fontweight': 'bold'}
    ax.set_title(title)
    ax.set_ylabel('True sleep stage', fontdict=label_font)
    ax.set_xlabel('Predicted sleep stage', fontdict=label_font)
    plt.tight_layout()

    out_dir = 'plots'
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    ax.figure.savefig(os.path.join(out_dir, fname))
    return ax
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号