def plot_pair_by_layer(ax, layers, all_accs, maj, mfl, title, hide_xlabel=False, hide_ylabel=False,
                       ymin=0, ymax=100, plot_maj=True, nbins=6, delta_above=True, delta_val=4):
    """Swarm-plot per-layer accuracies from several runs on ``ax`` and annotate each layer.

    Every layer gets a swarm of its per-run accuracies plus a text label of the form
    ``"+delta (std)"``. Here ``delta`` is the change in mean accuracy relative to the
    previous layer (0 for the first layer). ``std`` is the across-run population standard
    deviation (``ddof=0``). Horizontal reference lines are drawn at ``mfl`` and,
    optionally, at ``maj``.

    Args:
        ax: Axes to draw on (receives ``axhline``/``text``/``set_ylim`` calls and is
            passed to ``sns.swarmplot``).
        layers: Not used by this function. NOTE(review): confirm whether it was meant
            to label the x-axis; the x positions used here are ``0 .. num_layers-1``.
        all_accs: Accuracies indexed as ``all_accs[run][layer]``. Every run must have
            the same number of layers.
        maj: y value of the "Majority" reference line (drawn only if ``plot_maj``).
        mfl: y value of the "MFL" reference line (always drawn).
        title: Axes title.
        hide_xlabel: Blank out the x-axis label when True.
        hide_ylabel: Blank out the y-axis label when True.
        ymin: Lower y-axis limit.
        ymax: Upper y-axis limit.
        plot_maj: Whether to draw the majority reference line.
        nbins: Passed to ``ax.locator_params`` to control the number of y-axis ticks.
        delta_above: If True, place each annotation ``delta_val`` above the layer's
            highest run. Otherwise place it ``2 * delta_val`` below the layer's lowest run.
        delta_val: Vertical offset for the annotations, in data units.

    Returns:
        ``(maj_line, mfl_line)``: the artists returned by ``ax.axhline`` for the two
        reference lines, e.g. for building a shared legend. ``maj_line`` is None when
        ``plot_maj`` is False.

    Raises:
        ValueError: If ``all_accs`` is not a 2-D (runs x layers) collection.
    """
    accs = np.asarray(all_accs)
    if accs.ndim != 2:
        raise ValueError(
            'all_accs must be a 2-D (runs x layers) collection, got shape {}'.format(accs.shape))
    num_runs, num_layers = accs.shape

    # Per-layer statistics across runs.
    means = accs.mean(axis=0)
    stds = accs.std(axis=0)
    maxs = accs.max(axis=0)
    mins = accs.min(axis=0)  # lowest run per layer; anchors labels when delta_above is False
    # Change in mean accuracy from the previous layer; the first layer has no predecessor.
    deltas = np.concatenate(([0.0], np.diff(means)))

    # Long-form frame. Row-major flattening yields run 0's layers, then run 1's, ...,
    # so the Layer column repeats 0..num_layers-1 once per run.
    df = pd.DataFrame({'Layer': list(range(num_layers)) * num_runs,
                       'Accuracy': accs.ravel()})

    ax.set_ylim(ymin, ymax)
    sns.swarmplot(x='Layer', y='Accuracy', data=df, ax=ax)
    if hide_xlabel:
        ax.set_xlabel('')
    if hide_ylabel:
        ax.set_ylabel('')

    maj_line = (ax.axhline(y=maj, label='Majority', linestyle='--', color='black')
                if plot_maj else None)
    mfl_line = ax.axhline(y=mfl, label='MFL', linestyle='-.', color='black')

    # Annotate each layer; the categorical x positions are the layer indices 0..num_layers-1.
    for i, (delta, std, hi, lo) in enumerate(zip(deltas, stds, maxs, mins)):
        y = hi + delta_val if delta_above else lo - delta_val * 2
        ax.text(i, y, '{:+.1f} ({:.1f})'.format(delta, std),
                horizontalalignment='center', size='small')

    ax.locator_params(axis='y', nbins=nbins)
    ax.set_title(title)
    return maj_line, mfl_line
评论列表
文章目录