plot_journal.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:nmt-repr-analysis 作者: boknilev 项目源码 文件源码
def plot_pair_by_layer(ax, layers, all_accs, maj, mfl, title, hide_xlabel=False, hide_ylabel=False,
                       ymin=0, ymax=100, plot_maj=True, nbins=6, delta_above=True, delta_val=4):
    """Draw a per-layer swarm plot of accuracies from several runs on ``ax``.

    Each row of ``all_accs`` is one run and each column one layer. For every
    layer a text label ``"+d.d (s.s)"`` is written above the highest point
    (or below the lowest point) of its swarm, where ``d.d`` is the change in
    mean accuracy from the previous layer (0 for the first layer) and ``s.s``
    is the standard deviation across runs.

    Args:
        ax: matplotlib Axes to draw on.
        layers: Not used by this function; kept so existing callers work.
        all_accs: 2-D array-like, shape (num_runs, num_layers).
        maj: y-value of the dashed "Majority" reference line.
        mfl: y-value of the dash-dot "MFL" reference line.
        title: Axes title.
        hide_xlabel: Blank the x-axis label set by seaborn.
        hide_ylabel: Blank the y-axis label set by seaborn.
        ymin, ymax: y-axis limits.
        plot_maj: Whether to draw the "Majority" line.
        nbins: Passed to ``ax.locator_params`` for the y axis.
        delta_above: If True, put labels ``delta_val`` above each layer's
            maximum; otherwise ``2 * delta_val`` below its minimum.
        delta_val: Vertical offset of the labels.

    Returns:
        ``(maj_line, mfl_line)`` as returned by ``ax.axhline``; ``maj_line``
        is None when ``plot_maj`` is False.
    """
    accs = np.asarray(all_accs)

    # Per-layer statistics across runs (axis 0 indexes runs).
    means = np.mean(accs, axis=0)
    stds = np.std(accs, axis=0)
    maxs = np.max(accs, axis=0)
    # Must be the minimum: labels placed below the swarm (delta_above=False)
    # are offset from the lowest point of each layer.
    mins = np.min(accs, axis=0)
    # Change in mean relative to the previous layer; the first layer gets 0.
    deltas = np.concatenate(([0.0], np.diff(means)))

    num_runs, num_layers = accs.shape
    # Row-major flattening yields run0's layers, then run1's, ..., so the
    # layer index sequence is 0..num_layers-1 repeated once per run.
    df = pd.DataFrame({
        'Layer': list(range(num_layers)) * num_runs,
        'Accuracy': accs.ravel(),
    })

    ax.set_ylim(ymin, ymax)
    sns.swarmplot(x='Layer', y='Accuracy', data=df, ax=ax)
    if hide_xlabel:
        ax.set_xlabel('')
    if hide_ylabel:
        ax.set_ylabel('')

    maj_line = None
    if plot_maj:
        maj_line = ax.axhline(y=maj, label='Majority', linestyle='--', color='black')
    mfl_line = ax.axhline(y=mfl, label='MFL', linestyle='-.', color='black')

    # Annotate each layer (x positions are the category indices 0..n-1).
    for i, (delta, std, hi, lo) in enumerate(zip(deltas, stds, maxs, mins)):
        y = hi + delta_val if delta_above else lo - delta_val * 2
        ax.text(i, y, '{:+.1f} ({:.1f})'.format(delta, std),
                horizontalalignment='center', size='small')

    ax.locator_params(axis='y', nbins=nbins)
    ax.set_title(title)

    return maj_line, mfl_line
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号