plot_journal.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:nmt-repr-analysis 作者: boknilev 项目源码 文件源码
def plot_averages_by_distance(df, figname, fignum, use_en_source=True, num_accs=24, pointplot=True, hue='Distance'):
    """Plot accuracies grouped by layer and distance, save the figure, return the next figure number.

    Args:
        df: DataFrame with ``source`` and ``target`` columns, plus whatever
            columns ``get_accs_from_df`` reads with ``col_pref='dist'``.
        figname: Path handed to ``plt.savefig``.
        fignum: Matplotlib figure number to draw into.
        use_en_source: If True, keep rows whose source is ``'en'`` and target is
            not; otherwise keep rows whose source is not ``'en'`` and target is.
        num_accs: Number of accuracy groups expected from ``get_accs_from_df``.
            Each group is assumed to hold ``num_layers * reps_per_layer`` (25)
            values ordered to match the layer labels built below -- TODO confirm
            against ``get_accs_from_df``.
        pointplot: Draw with ``sns.pointplot`` if True, else ``sns.boxplot``.
        hue: ``'Distance'`` puts layers on the x-axis coloured by distance; any
            other value puts distances on the x-axis coloured by layer.

    Returns:
        ``fignum + 1``, so callers can chain figure numbers.

    Raises:
        ValueError: If the number of accuracies does not match the number of
            layer labels, or the rows cannot be split evenly across the entries
            of ``pretty_dist_names_list``.
    """
    num_layers = 5       # layer labels 0..4 written to the 'Layer' column
    reps_per_layer = 5   # each layer label occurs this many times per accuracy group

    plt.figure(fignum)
    if use_en_source:
        df_side = df[(df.source == 'en') & (df.target != 'en')]
        # Per group: 0,0,0,0,0,1,1,1,1,1,... (each layer repeated consecutively).
        layers = np.concatenate([[i] * reps_per_layer for i in range(num_layers)] * num_accs)
    else:
        df_side = df[(df.source != 'en') & (df.target == 'en')]
        # Per group: 0,1,2,3,4,0,1,2,3,4,... (layers cycle).
        layers = list(range(num_layers)) * reps_per_layer * num_accs

    accs = get_accs_from_df(df_side, col_pref='dist')
    flat_accs = np.concatenate(accs)

    n_rows = len(layers)
    if len(flat_accs) != n_rows:
        raise ValueError(
            'expected %d accuracies for num_accs=%d, got %d'
            % (n_rows, num_accs, len(flat_accs)))

    # Distance labels occupy contiguous, equally sized runs of rows. The run length
    # was previously hard-coded (75 rows x 8 distances), which only matched num_accs=24.
    num_dists = len(pretty_dist_names_list)
    if num_dists == 0 or n_rows % num_dists:
        raise ValueError(
            'cannot split %d rows evenly across %d distance names'
            % (n_rows, num_dists))
    per_dist = n_rows // num_dists
    dists = np.concatenate([[name] * per_dist for name in pretty_dist_names_list])

    df_plot = pd.DataFrame({'Layer': layers, 'Accuracy': flat_accs, 'Distance': dists})
    plotfunc = sns.pointplot if pointplot else sns.boxplot
    if hue == 'Distance':
        plotfunc(x='Layer', y='Accuracy', data=df_plot, hue='Distance')
    else:
        plotfunc(x='Distance', y='Accuracy', data=df_plot, hue='Layer')
        plt.xticks(range(num_dists), pretty_dist_names_list)

    plt.tight_layout()
    plt.savefig(figname)
    return fignum + 1
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号