spectral_mixture_gp_regression_test.py source file

python

Project: gpytorch  Author: jrg365
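import torch
from torch import optim
import gpytorch

# SpectralMixtureGPModel, train_x, train_y, test_x and test_y are assumed to be
# defined at module level elsewhere in this test file (not shown here).
def test_spectral_mixture_gp_mean_abs_error():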
    gp_model = SpectralMixtureGPModel()

    # Optimize the model
    gp_model.train()
    optimizer = optim.Adam(gp_model.parameters(), lr=0.1)
    optimizer.n_iter = 0

    gpytorch.functions.fastest = False
    for i in range(50):
        optimizer.zero_grad()
        output = gp_model(train_x)
        loss = -gp_model.marginal_log_likelihood(output, train_y)
        loss.backward()
        optimizer.n_iter += 1
        optimizer.step()

    # Test the model
    gp_model.eval()
    gp_model.condition(train_x, train_y)
    test_preds = gp_model(test_x).mean()
    mean_abs_error = torch.mean(torch.abs(test_y - test_preds))

    # The spectral mixture kernel should easily extrapolate the sine function.
    assert mean_abs_error.data.squeeze()[0] < 0.05
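
The test above relies on the spectral mixture kernel being well suited to periodic data such as a sine wave. As a rough illustration of why, here is a minimal standalone sketch of the one-dimensional spectral mixture kernel in its usual published form, k(tau) = sum_q w_q * exp(-2 * pi^2 * tau^2 * v_q) * cos(2 * pi * tau * mu_q). It does not use gpytorch and is not the library's implementation; the function name `spectral_mixture_kernel` and its parameter names are hypothetical, chosen only for this example.

```python
import math


def spectral_mixture_kernel(tau, weights, means, variances):
    """Evaluate a 1-D spectral mixture kernel at lag tau = x - x'.

    Hypothetical helper for illustration only (not part of gpytorch).
    weights, means and variances hold one entry per mixture component.
    """
    total = 0.0
    for w, mu, v in zip(weights, means, variances):
        # Gaussian envelope (controls how fast correlation decays with lag)
        envelope = math.exp(-2.0 * math.pi ** 2 * tau ** 2 * v)
        # Cosine term (makes the covariance periodic with frequency mu)
        total += w * envelope * math.cos(2.0 * math.pi * tau * mu)
    return total


# A single component with frequency 1 and a very small variance:
# points one full period apart stay strongly correlated, and points
# half a period apart are strongly anti-correlated.
k_zero = spectral_mixture_kernel(0.0, [1.0], [1.0], [1e-4])
k_full_period = spectral_mixture_kernel(1.0, [1.0], [1.0], [1e-4])
k_half_period = spectral_mixture_kernel(0.5, [1.0], [1.0], [1e-4])
```

Because the covariance between points a whole period apart stays close to its value at zero lag, a GP using such a kernel can carry a learned oscillation beyond the training range, which is the behavior the mean absolute error check in the test is probing.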