def plot_averages_by_distance(df, figname, fignum, use_en_source=True,
                              num_accs=24, pointplot=True, hue='Distance'):
    """Plot distance-prefixed accuracies by layer and distance, then save the figure.

    Selects one translation direction from ``df``, extracts the ``dist``
    accuracy columns via ``get_accs_from_df``, builds a long-form DataFrame
    with ``Layer`` / ``Accuracy`` / ``Distance`` columns and draws it with
    seaborn into matplotlib figure ``fignum``.

    Args:
        df: DataFrame with ``source`` and ``target`` columns, plus whatever
            ``get_accs_from_df`` reads for ``col_pref='dist'``.
        figname: Path passed to ``plt.savefig``.
        fignum: Matplotlib figure number to draw into.
        use_en_source: If True, keep rows with source ``'en'`` and target not
            ``'en'``; otherwise keep rows with target ``'en'`` and source not
            ``'en'``. This also selects the layer-label ordering (see below).
        num_accs: Number of 25-value chunks in the flattened accuracies; the
            layer-label column has ``25 * num_accs`` entries.
        pointplot: Use ``sns.pointplot`` if True, else ``sns.boxplot``.
        hue: ``'Distance'`` puts layers on the x-axis and colours by distance;
            any other value puts distances on the x-axis and colours by layer.

    Returns:
        ``fignum + 1``, the next figure number for the caller to use.

    Raises:
        ValueError: If the number of accuracies does not match the number of
            layer labels, or cannot be split evenly across the distance names.
    """
    plt.figure(fignum)
    if use_en_source:
        df_side = df[(df.source == 'en') & (df.target != 'en')]
        # Per 25-value chunk the labels run 0,0,0,0,0,1,...,4 (each layer 5x).
        layers = np.concatenate([[i] * 5 for i in range(5)] * num_accs)
    else:
        df_side = df[(df.source != 'en') & (df.target == 'en')]
        # Per 25-value chunk the labels cycle 0,1,2,3,4 five times.
        layers = list(range(5)) * 5 * num_accs
    accs = get_accs_from_df(df_side, col_pref='dist')
    flat_accs = np.concatenate(accs)

    num_dists = len(pretty_dist_names_list)
    if len(flat_accs) != len(layers) or len(flat_accs) % num_dists:
        raise ValueError(
            f"expected {len(layers)} accuracies divisible across {num_dists} "
            f"distances, got {len(flat_accs)}"
        )
    # Each distance owns an equal, contiguous run of the flattened accuracies.
    # Derived from the data rather than a hard-coded 75, which only held for
    # num_accs=24 with exactly 8 distance names.
    per_dist = len(flat_accs) // num_dists
    dists = np.repeat(pretty_dist_names_list, per_dist)

    df_plot = pd.DataFrame({'Layer': layers, 'Accuracy': flat_accs, 'Distance': dists})
    plotfunc = sns.pointplot if pointplot else sns.boxplot
    if hue == 'Distance':
        plotfunc(x='Layer', y='Accuracy', data=df_plot, hue='Distance')
    else:
        plotfunc(x='Distance', y='Accuracy', data=df_plot, hue='Layer')
        # Relabel only when distances are on the x-axis; applying this
        # unconditionally stamped distance names onto the Layer axis.
        plt.xticks(range(num_dists), pretty_dist_names_list)
    plt.tight_layout()
    plt.savefig(figname)
    return fignum + 1
评论列表
文章目录