visualisation.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:bnn-analysis 作者: myshkov 项目源码 文件源码
def plot_predictive_comparison(env, baseline_samples, target_samples, stddev_mult=3., target_metrics=None,
                               title_name=None):
    """Plot a target posterior predictive against a baseline one and save the figure.

    Single-variable regression only: just column 0 of the environment's train/test
    inputs and targets is used.

    Args:
        env: Environment providing ``get_train_x()``, ``get_train_y()``,
            ``get_test_x()``, ``get_test_y()`` (2-D arrays; column 0 is plotted)
            and the axis limits ``view_xrange`` / ``view_yrange`` (indexable pairs).
        baseline_samples: Predictive samples of the baseline model. After
            ``squeeze()`` it must be laid out as ``(n_samples, n_test)`` so the
            per-point mean/std line up with the test inputs.
        target_samples: Predictive samples of the model being evaluated; same
            layout as ``baseline_samples``.
        stddev_mult: Multiplier applied to the sample standard deviation to get
            the half-width of the error bars.
        target_metrics: Optional mapping ``name -> number``; each entry is written
            as a ``"name: value"`` annotation near the top-left of the figure.
        title_name: Optional suffix appended to the title as ``": <title_name>"``.

    Raises:
        ValueError: If the environment has more training points than test points
            (training data is NaN-padded up to the test length).
    """
    # single var regression only: drop singleton dimensions of the sample arrays
    baseline_samples = baseline_samples.squeeze()
    target_samples = target_samples.squeeze()

    train_x, train_y = env.get_train_x(), env.get_train_y()
    test_x, test_y = env.get_test_x(), env.get_test_y()

    # Training points live in the same DataFrame as the test grid, so they are
    # padded with NaN up to the test length to make all columns equally long.
    pad_width = test_x.shape[0] - train_x.shape[0]
    if pad_width < 0:
        raise ValueError(
            f'Expected at most as many training points as test points, got '
            f'{train_x.shape[0]} training and {test_x.shape[0]} test points.')
    train_x_padded = np.pad(train_x[:, 0], (0, pad_width), 'constant', constant_values=np.nan)
    train_y_padded = np.pad(train_y[:, 0], (0, pad_width), 'constant', constant_values=np.nan)

    df = pd.DataFrame.from_dict({
        'time': test_x[:, 0],
        'true_y': test_y[:, 0],
        'train_x': train_x_padded,
        'train_y': train_y_padded,
        'mean': target_samples.mean(axis=0),
        'std': stddev_mult * target_samples.std(axis=0),
        'base_mean': baseline_samples.mean(axis=0),
        'base_std': stddev_mult * baseline_samples.std(axis=0),
    }).reset_index()

    # `height` is seaborn's name for the former `size` argument (renamed in 0.9).
    g = sns.FacetGrid(df, height=9, aspect=1.8)

    g.map(plt.errorbar, 'time', 'base_mean', 'base_std', color=(0.7, 0.1, 0.1, 0.09))
    g.map(plt.errorbar, 'time', 'mean', 'std', color=(0.1, 0.1, 0.7, 0.09))
    g.map(plt.plot, 'time', 'mean', color='b', lw=1)
    g.map(plt.plot, 'time', 'true_y', color='r', lw=1)
    g.map(plt.scatter, 'train_x', 'train_y', color='g', s=20)

    ax = g.ax
    # Build the suffix separately: a bare `a + b if cond else ''` would blank the
    # whole title whenever title_name is None.
    title = 'Posterior Predictive Distribution'
    if title_name is not None:
        title += f': {title_name}'
    ax.set_title(title)
    ax.set(xlabel='X', ylabel='Y')
    ax.set_xlim(env.view_xrange[0], env.view_xrange[1])
    ax.set_ylim(env.view_yrange[0], env.view_yrange[1])

    # NOTE: labels are assigned positionally to the axes' legend handles, so this
    # order must be kept in sync with the plotting calls above.
    legend = ['Prediction mean', 'True f(x)', 'Training data', 'True StdDev', 'Predicted StdDev']
    ax.legend(legend)

    if target_metrics is not None:
        offset = 0
        for tm, tv in target_metrics.items():
            # One line per metric, stacked downwards from the top-left corner.
            ax.annotate(f'{tm}: {tv:.02f}', xy=(0.08, 0.92 - offset), xytext=(0.08, 0.92 - offset),
                        xycoords='figure fraction', textcoords='figure fraction')
            offset += 0.04

    # Derive the output file name from the data directory so runs don't collide.
    name = utils.DATA_DIR.replace('/', '-')
    plt.tight_layout(pad=0.6)
    utils.save_fig('predictive-distribution-' + name)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号