gpr_alpha_examples.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:geepee 作者: thangbui 项目源码 文件源码
def plot(m, Xtrain, ytrain):
    xx = np.linspace(-0.5, 1.5, 100)[:, None]
    mean, var = m.predict_y(xx)
    mean = np.reshape(mean, (xx.shape[0], 1))
    var = np.reshape(var, (xx.shape[0], 1))
    if isinstance(m, aep.SDGPR):
        zu = m.sgp_layers[0].zu
    elif isinstance(m, vfe.SGPR_collapsed):
        zu = m.zu
    else:
        zu = m.sgp_layer.zu
    mean_u, var_u = m.predict_f(zu)
    plt.figure()
    plt.plot(Xtrain, ytrain, 'kx', mew=2)
    plt.plot(xx, mean, 'b', lw=2)
    # pdb.set_trace()
    plt.fill_between(
        xx[:, 0],
        mean[:, 0] - 2 * np.sqrt(var[:, 0]),
        mean[:, 0] + 2 * np.sqrt(var[:, 0]),
        color='blue', alpha=0.2)
    plt.errorbar(zu, mean_u, yerr=2 * np.sqrt(var_u), fmt='ro')
    plt.xlim(-0.1, 1.1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号