def make_avg_traj_figs(model):
    """
    Return one line-plot figure per evaluation trajectory of ``model``.

    A figure is built for every name in ``AppConfigs.EVAL_NAMES``.

    :param model: object exposing ``eval_name_range_dict`` (indexed by
        trajectory name to obtain the y-axis limits), ``get_data_step_axis()``
        (x values) and ``get_traj(name)`` (y values for that trajectory).
    :return: list of matplotlib figures, in ``AppConfigs.EVAL_NAMES`` order.
    """
    def make_avg_traj_fig(traj_name):
        """
        Return a figure plotting trajectory ``traj_name`` against mini batches.

        Also prints how long building the figure took.
        """
        start = time.perf_counter()
        sns.set_style('white')
        # presumably a (bottom, top) pair accepted by Axes.set_ylim -- confirm
        ylims = model.eval_name_range_dict[traj_name]
        # load data
        x = model.get_data_step_axis()
        y = model.get_traj(traj_name)
        # fig
        fig, ax = plt.subplots(figsize=(FigsConfigs.MAX_FIG_WIDTH, 3),
                               dpi=FigsConfigs.DPI)
        ax.set_ylim(ylims)
        ax.set_xlabel('Mini Batch', fontsize=FigsConfigs.AXLABEL_FONT_SIZE)
        ax.set_ylabel(traj_name, fontsize=FigsConfigs.AXLABEL_FONT_SIZE)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        # Pass real booleans: newer matplotlib no longer translates the
        # strings 'on'/'off', and a non-empty string like 'off' is truthy.
        ax.tick_params(axis='both', which='both', top=False, right=False)
        ax.xaxis.set_major_formatter(FuncFormatter(human_format))
        ax.yaxis.grid(True)
        # plot
        ax.plot(x, y, '-', linewidth=FigsConfigs.LINEWIDTH, color='black')
        # Lay out this specific figure rather than whatever pyplot considers current.
        fig.tight_layout()
        print('{} completed in {:.1f} secs'.format(
            sys._getframe().f_code.co_name, time.perf_counter() - start))
        return fig

    return [make_avg_traj_fig(traj_name) for traj_name in AppConfigs.EVAL_NAMES]
评论列表
文章目录