comparisonfigs.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:rnnlab 作者: phueb 项目源码 文件源码
def make_test_and_train_pp_traj_fig(models1, models2=None, palette=None, ):
    """
    Return a figure showing the trajectory of test and train perplexity.

    For every model group the per-model 'test_pp' and 'train_pp' trajectories
    are averaged point-wise and drawn against the step axis of the group's
    first model (solid line: test, dashed line: train). A grey band of
    +/- ``stats.sem`` across the models of the group surrounds each line.

    Args:
        models1: non-empty sequence of model objects providing
            ``get_traj(name)`` and ``get_data_step_axis()``. The first
            model's ``terms.num_set_`` is used as the upper y-axis limit.
        models2: optional second group of models, plotted on the same axes.
        palette: optional iterable (or iterator) of colors; one color is
            consumed per model group. When None, every group is black.

    Returns:
        The matplotlib figure that was created.
    """
    start = time.time()
    sns.set_style('white')
    # Accept any iterable of colors, not only an iterator: iter() on an
    # iterator returns the same object, so existing callers are unaffected.
    colors = iter(palette) if palette is not None else None
    # load data
    model_groups = [models1] if models2 is None else [models1, models2]
    xys = []
    for models in model_groups:
        test_trajs = [model.get_traj('test_pp') for model in models]
        train_trajs = [model.get_traj('train_pp') for model in models]
        x = models[0].get_data_step_axis()
        # Clip everything to the shortest common length so that the
        # trajectory matrices are rectangular and x matches y when plotting.
        num_steps = min([len(x)] + [len(traj) for traj in test_trajs + train_trajs])
        x = x[:num_steps]
        test_mat = np.asarray([traj[:num_steps] for traj in test_trajs])
        train_mat = np.asarray([traj[:num_steps] for traj in train_trajs])
        y1 = np.mean(test_mat, axis=0)
        y2 = np.mean(train_mat, axis=0)
        # One value per step, computed across models (rows).
        sem1 = stats.sem(test_mat, axis=0)
        sem2 = stats.sem(train_mat, axis=0)
        xys.append((x, y1, y2, sem1, sem2))
    # fig
    fig, ax = plt.subplots(figsize=(FigsConfigs.MAX_FIG_WIDTH, 3))
    ax.set_ylim([0, models1[0].terms.num_set_])
    ax.set_ylabel('Perplexity', fontsize=FigsConfigs.AXLABEL_FONT_SIZE)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    # Booleans are required here: the string 'off' is a truthy value and
    # would leave the top/right ticks switched on.
    ax.tick_params(axis='both', which='both', top=False, right=False)
    ax.set_xlabel('Mini Batch', fontsize=FigsConfigs.AXLABEL_FONT_SIZE)
    ax.xaxis.set_major_formatter(FuncFormatter(human_format))
    ax.yaxis.grid(True)
    # plot
    for (x, y1, y2, sem1, sem2) in xys:
        color = next(colors) if colors is not None else 'black'
        # The line style is given once, via the keyword only; combining it
        # with a format string defines it twice.
        ax.plot(x, y1, linewidth=FigsConfigs.LINEWIDTH, color=color, linestyle='-', label='Test')
        ax.plot(x, y2, linewidth=FigsConfigs.LINEWIDTH, color=color, linestyle='--', label='Train')
        ax.fill_between(x, np.add(y1, sem1), np.subtract(y1, sem1), alpha=FigsConfigs.FILL_ALPHA, color='grey')
        ax.fill_between(x, np.add(y2, sem2), np.subtract(y2, sem2), alpha=FigsConfigs.FILL_ALPHA, color='grey')
    # Operate on the objects created above rather than on pyplot's implicit
    # "current figure" state.
    fig.tight_layout()
    ax.legend(loc='best')
    print('{} completed in {:.1f} secs'.format(sys._getframe().f_code.co_name, time.time() - start))
    return fig
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号