def make_probes_pp_traj_fig(models1, models2=None, palette=None):
    """
    Return a figure showing the trajectory of Probes Perplexity.

    For every model group, the 'probes_pp' (plotted as "weighted") and
    'probes_pp_uw' (plotted as "unweighted") trajectories of all models in
    the group are averaged element-wise and drawn as a solid and a dashed
    line, respectively, in the same color.

    Args:
        models1: Non-empty sequence of model objects exposing
            ``get_traj(name)`` and ``get_data_step_axis()``.
        models2: Optional second group of models, plotted the same way as
            ``models1``.
        palette: Optional iterator of colors; ``next(palette)`` is called
            once per model group. When omitted, every line is black.

    Returns:
        The matplotlib figure holding the plot.
    """
    start = time.time()
    sns.set_style('white')
    # load data: one (x, mean weighted, mean unweighted) triple per group
    xys = []
    model_groups = [models1] if models2 is None else [models1, models2]
    for models in model_groups:
        probes_pp_trajs_w = []
        probes_pp_trajs_uw = []
        for model in models:
            probes_pp_trajs_w.append(model.get_traj('probes_pp'))
            probes_pp_trajs_uw.append(model.get_traj('probes_pp_uw'))
        # the first model of the group supplies the shared x axis; every
        # trajectory is truncated to its length before averaging
        x = models[0].get_data_step_axis()
        num_steps = len(x)
        traj_mat1 = np.asarray([traj[:num_steps] for traj in probes_pp_trajs_w])
        traj_mat2 = np.asarray([traj[:num_steps] for traj in probes_pp_trajs_uw])
        y1 = np.mean(traj_mat1, axis=0)
        y2 = np.mean(traj_mat2, axis=0)
        xys.append((x, y1, y2))
    # fig
    fig, ax = plt.subplots(figsize=(FigsConfigs.MAX_FIG_WIDTH, 3))
    ylabel = 'Probes Perplexity'
    ax.set_ylabel(ylabel, fontsize=FigsConfigs.AXLABEL_FONT_SIZE)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    # booleans, not the strings 'off': a non-empty string is truthy, so the
    # legacy spelling does not reliably hide the ticks in current matplotlib
    ax.tick_params(axis='both', which='both', top=False, right=False)
    ax.set_xlabel('Mini Batch', fontsize=FigsConfigs.AXLABEL_FONT_SIZE)
    ax.xaxis.set_major_formatter(FuncFormatter(human_format))
    ax.yaxis.grid(True)
    # plot
    for (x, y1, y2) in xys:
        color = next(palette) if palette is not None else 'black'
        # line style is given once via the keyword; combining it with a
        # format string defines it redundantly (and contradictorily for the
        # dashed line), which matplotlib warns about
        ax.plot(x, y1, linewidth=FigsConfigs.LINEWIDTH, color=color,
                linestyle='-', label='weighted')
        ax.plot(x, y2, linewidth=FigsConfigs.LINEWIDTH, color=color,
                linestyle='--', label='unweighted')
    # act on the objects created above instead of pyplot's global state
    ax.legend(loc='best')
    fig.tight_layout()
    print('{} completed in {:.1f} secs'.format(sys._getframe().f_code.co_name, time.time() - start))
    return fig
# Comment list
# Table of contents