gpr_ep_examples.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:geepee 作者: thangbui 项目源码 文件源码
def _timed_pep_run(X, Y, M, hyper_source, alpha, parallel, compute_energy):
    """Build a fresh ``pep.SGPR_rank_one`` model, run PEP on it, and time it.

    The model is constructed from ``(X, Y, M)`` with a Gaussian likelihood,
    its hyperparameters are set from ``hyper_source.get_hypers()``, and
    ``run_pep`` is called over the indices of every row of ``X``.

    Args:
        X, Y: training inputs/targets passed straight to the model.
        M: third constructor argument of ``pep.SGPR_rank_one``.
        hyper_source: any object with a ``get_hypers()`` method (here, the
            AEP-optimised model) used to initialise the PEP model.
        alpha: forwarded to ``run_pep``.
        parallel: forwarded to ``run_pep`` (sequential vs parallel updates).
        compute_energy: forwarded to ``run_pep``.

    Returns:
        (model, elapsed): the PEP model after ``run_pep`` and the wall-clock
        seconds spent on construction, hyperparameter loading and ``run_pep``.
    """
    start_time = time.time()
    model = pep.SGPR_rank_one(X, Y, M, lik='Gaussian')
    model.update_hypers(hyper_source.get_hypers())
    # 10 is run_pep's second positional argument, kept from the original
    # example (presumably the number of PEP passes -- confirm in pep.py).
    model.run_pep(np.arange(X.shape[0]), 10, alpha=alpha,
                  parallel=parallel, compute_energy=compute_energy)
    elapsed = time.time() - start_time
    return model, elapsed


def run_regression_1D_pep_inference():
    """Toy 1-D regression: fit an AEP SGPR, then time PEP runs started from it.

    A synthetic dataset y = sin(12x) + 0.5 cos(25x) + N(0, 0.2^2) noise is
    drawn at 200 uniform inputs in [0, 1) (seeded with 42). An ``aep.SGPR``
    model is optimised with L-BFGS-B, and its hyperparameters are used to
    initialise two ``pep.SGPR_rank_one`` runs (sequential and parallel
    updates) whose wall-clock times are printed.

    Plotting/saving calls are left commented out; uncomment them to inspect
    the fits visually.
    """
    np.random.seed(42)

    # Single-argument print(...) behaves identically on Python 2 and 3,
    # unlike the Python-2-only ``print "..."`` statement.
    print("create dataset ...")
    N = 200
    X = np.random.rand(N, 1)
    Y = np.sin(12 * X) + 0.5 * np.cos(25 * X) + np.random.randn(N, 1) * 0.2
    # plt.plot(X, Y, 'kx', mew=2)

    def plot(m):
        """Plot data, predictive mean +/- 2 std, and predictions at m's zu."""
        xx = np.linspace(-0.5, 1.5, 100)[:, None]
        mean, var = m.predict_f(xx)
        zu = m.sgp_layer.zu
        mean_u, var_u = m.predict_f(zu)
        plt.figure()
        plt.plot(X, Y, 'kx', mew=2)
        plt.plot(xx, mean, 'b', lw=2)
        plt.fill_between(
            xx[:, 0],
            mean[:, 0] - 2 * np.sqrt(var[:, 0]),
            mean[:, 0] + 2 * np.sqrt(var[:, 0]),
            color='blue', alpha=0.2)
        # errorbar expects 1-D x/y/yerr: an (M, 1) yerr cannot be broadcast
        # to (2, M) and is rejected by current matplotlib, so flatten.
        plt.errorbar(np.ravel(zu), mean_u[:, 0],
                     yerr=2 * np.sqrt(var_u[:, 0]), fmt='ro')
        plt.xlim(-0.1, 1.1)

    # inference
    print("create aep model and optimize ...")
    M = 20
    alpha = 0.5
    model_aep = aep.SGPR(X, Y, M, lik='Gaussian')
    model_aep.optimise(method='L-BFGS-B', alpha=alpha, maxiter=2000)
    # plot(model_aep)
    # plt.show()
    # plt.savefig('/tmp/gpr_aep_reg.pdf')

    # Both PEP runs are initialised from the AEP-optimised hyperparameters.
    # NOTE: the sequential run also computes the energy (compute_energy=True)
    # while the parallel run does not, so the two timings are not strictly
    # like-for-like.
    model, elapsed = _timed_pep_run(X, Y, M, model_aep, alpha,
                                    parallel=False, compute_energy=True)
    print("sequential updates: %.4f" % elapsed)
    # plot(model)
    # plt.savefig('/tmp/gpr_pep_reg_seq.pdf')

    model, elapsed = _timed_pep_run(X, Y, M, model_aep, alpha,
                                    parallel=True, compute_energy=False)
    print("parallel updates: %.4f" % elapsed)
    # plot(model)
    # plt.savefig('/tmp/gpr_pep_reg_par.pdf')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号