kissgp_additive_classification_test.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def test_kissgp_classification_error():
    """Fit the GP classifier on the training set and check its training error.

    The test disables ``gpytorch.functions.use_toeplitz`` for its duration,
    runs 100 Adam steps minimizing the negative marginal log likelihood,
    then thresholds the predictive mean at 0.5 and asserts that the mean
    absolute error against ``train_y`` is below 5%.

    The flag is a process-wide setting, so it is restored in a ``finally``
    clause: a failure anywhere in training or prediction must not leak the
    non-default value into other tests.
    """
    gpytorch.functions.use_toeplitz = False
    try:
        model = GPClassificationModel()

        # Find optimal model hyperparameters
        model.train()
        optimizer = optim.Adam(model.parameters(), lr=0.15)
        # Step counter attached to the optimizer instance.
        # NOTE(review): presumably read by the model/library during
        # training -- confirm before removing.
        optimizer.n_iter = 0
        for _ in range(100):
            optimizer.zero_grad()
            output = model.forward(train_x)
            loss = -model.marginal_log_likelihood(output, train_y)
            loss.backward()
            optimizer.n_iter += 1
            optimizer.step()

        # Set back to eval mode
        model.eval()
        # Threshold the predictive mean at 0.5, then map {0, 1} -> {-1, +1}.
        test_preds = model(train_x).mean().ge(0.5).float().mul(2).sub(1).squeeze()
        # Dividing by 2 turns each |label - prediction| into a 0/1 error,
        # assuming train_y holds -1/+1 labels -- TODO confirm.
        mean_abs_error = torch.mean(torch.abs(train_y - test_preds) / 2)
    finally:
        # Always put the global flag back, even if something above raised.
        gpytorch.functions.use_toeplitz = True
    assert mean_abs_error.data.squeeze()[0] < 5e-2
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号