visualisation.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:bnn-analysis 作者: myshkov 项目源码 文件源码
def plot_predictive_baseline(env, samples, stddev_mult=3., title_name=None):
    """Plot the posterior predictive distribution of a 1-D regression and save it.

    Over the test inputs, draws the per-point mean of ``samples`` with error
    bars of half-width ``stddev_mult * std``, the true test targets, and the
    training points. The figure is then written with ``utils.save_fig``.

    Args:
        env: Environment object providing ``get_train_x``, ``get_train_y``,
            ``get_test_x``, ``get_test_y`` and the ``view_xrange`` /
            ``view_yrange`` axis limits. Only the first column of each x/y
            array is plotted.
        samples: Predictive samples. After ``squeeze()``, mean and standard
            deviation are taken over axis 0 (presumably the sample axis,
            leaving one value per test point -- TODO confirm against callers).
        stddev_mult: Multiplier applied to the per-point standard deviation to
            get the error-bar half-width.
        title_name: Optional suffix appended to the title as ``': <name>'``.

    Raises:
        ValueError: If there are more training points than test points. The
            training columns are NaN-padded up to the test length, which
            requires ``len(test) >= len(train)``.
    """
    # Single-variable regression only.
    samples = samples.squeeze()

    train_x, train_y = env.get_train_x(), env.get_train_y()
    test_x, test_y = env.get_test_x(), env.get_test_y()

    # Every DataFrame column must have the test length, so the training
    # columns are padded at the end with NaN to line up with the test rows.
    pad_width = test_x.shape[0] - train_x.shape[0]
    if pad_width < 0:
        raise ValueError(
            'expected at least as many test points as training points, '
            'got {} test and {} train'.format(test_x.shape[0], train_x.shape[0]))
    train_x_padded = np.pad(train_x[:, 0], (0, pad_width), 'constant', constant_values=np.nan)
    train_y_padded = np.pad(train_y[:, 0], (0, pad_width), 'constant', constant_values=np.nan)

    df = pd.DataFrame.from_dict({
        'time': test_x[:, 0],
        'true_y': test_y[:, 0],
        'train_x': train_x_padded,
        'train_y': train_y_padded,
        'mean': samples.mean(axis=0),
        'std': stddev_mult * samples.std(axis=0),
    }).reset_index()

    # NOTE(review): seaborn 0.9 renamed FacetGrid's `size` keyword to `height`.
    # It is kept here for the seaborn version this project targets -- confirm
    # before upgrading seaborn.
    g = sns.FacetGrid(df, size=9, aspect=1.8)

    # The third positional argument of plt.errorbar is `yerr`, so the bars
    # span mean +/- std (std already includes stddev_mult).
    g.map(plt.errorbar, 'time', 'mean', 'std', color=(0.7, 0.1, 0.1, 0.09))
    g.map(plt.plot, 'time', 'mean', color='b', lw=1)
    g.map(plt.plot, 'time', 'true_y', color='r', lw=1)
    g.map(plt.scatter, 'train_x', 'train_y', color='g', s=20)

    ax = g.ax
    # Build the title explicitly. The previous one-line conditional expression
    # bound the `if/else` around the whole concatenation, producing an empty
    # title whenever title_name was None.
    title = 'Posterior Predictive Distribution'
    if title_name is not None:
        title += ': ' + title_name
    ax.set_title(title)
    ax.set(xlabel='X', ylabel='Y')
    ax.set_xlim(env.view_xrange[0], env.view_xrange[1])
    ax.set_ylim(env.view_yrange[0], env.view_yrange[1])

    # NOTE(review): legend labels are matched to artists positionally; verify
    # that this order lines up with the series drawn above.
    legend = ['Prediction mean', 'True f(x)', 'Training data', 'StdDev']
    plt.legend(legend)

    # File name suffix is the data directory with '/' replaced by '-'.
    name = utils.DATA_DIR.replace('/', '-')
    plt.tight_layout(pad=0.6)
    utils.save_fig('predictive-distribution-' + name)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号