kissgp_gp_regression_test.py source file

python

Project: gpytorch  Author: jrg365
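# Imports assumed from usage below; make_data and GPRegressionModel are
# defined elsewhere in this test file and are not shown in this excerpt.
import torch
from torch import optim


def test_kissgp_gp_mean_abs_error_cuda():
    # The whole test is a no-op when no CUDA device is available.
    if torch.cuda.is_available():
        train_x, train_y, test_x, test_y = make_data(cuda=True)
        gp_model = GPRegressionModel().cuda()

        # Optimize the model: 25 Adam steps on the negative
        # marginal log likelihood of the training data.
        gp_model.train()
        optimizer = optim.Adam(gp_model.parameters(), lr=0.1)
        # Iteration counter stored on the optimizer, bumped once per step.
        optimizer.n_iter = 0
        for _ in range(25):
            optimizer.zero_grad()
            output = gp_model(train_x)
            loss = -gp_model.marginal_log_likelihood(output, train_y)
            loss.backward()
            optimizer.n_iter += 1
            optimizer.step()

        # Test the model: condition on the training data, then
        # take the predictive mean at the test inputs.
        gp_model.eval()
        gp_model.condition(train_x, train_y)
        test_preds = gp_model(test_x).mean()
        mean_abs_error = torch.mean(torch.abs(test_y - test_preds))

        # Extract the scalar and require it to be below the 0.02 tolerance.
        assert mean_abs_error.data.squeeze()[0] < 0.02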
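
The pass criterion in the test above is a mean absolute error below 0.02. As a rough, dependency-free illustration of that metric only, the sketch below computes it over plain Python lists instead of tensors. The helper name mean_abs_error and the sample values are hypothetical and are not part of gpytorch or of the original test file.

```python
def mean_abs_error(targets, predictions):
    """Average of |target - prediction| over paired values.

    Hypothetical helper for illustration; the original test computes the
    same quantity with torch.mean(torch.abs(test_y - test_preds)).
    """
    if len(targets) != len(predictions):
        raise ValueError("targets and predictions must have the same length")
    if not targets:
        raise ValueError("need at least one value")
    total = sum(abs(t - p) for t, p in zip(targets, predictions))
    return total / len(targets)


# Made-up values, chosen only to show the threshold check.
sample_targets = [0.0, 0.5, 1.0]
sample_predictions = [0.0, 0.5, 1.03]
error = mean_abs_error(sample_targets, sample_predictions)
print(error < 0.02)
```

A threshold test like this accepts small per-point deviations as long as their average stays under the tolerance, which is why the original test compares a single scalar rather than checking each prediction individually.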