def test_kissgp_classification_error():
    """Train the KISS-GP classifier without Toeplitz and check training error.

    The model is fit with 100 Adam steps (lr=0.15) on the module-level
    ``train_x`` / ``train_y`` by minimising the negative marginal log
    likelihood. Predictions on ``train_x`` are thresholded at 0.5 and mapped
    to -1 / +1, and the test asserts that the halved mean absolute
    difference to ``train_y`` is below 5e-2.

    The global ``gpytorch.functions.use_toeplitz`` flag is turned off for the
    duration of the test and is always switched back on afterwards, even if
    training or prediction raises, so that a failure here cannot leak the
    disabled flag into other tests.
    """
    gpytorch.functions.use_toeplitz = False
    try:
        model = GPClassificationModel()

        # Find optimal model hyperparameters
        model.train()
        optimizer = optim.Adam(model.parameters(), lr=0.15)
        # Step counter stored on the optimizer itself, as in the original test.
        optimizer.n_iter = 0
        for _ in range(100):
            optimizer.zero_grad()
            output = model.forward(train_x)
            loss = -model.marginal_log_likelihood(output, train_y)
            loss.backward()
            optimizer.n_iter += 1
            optimizer.step()

        # Set back to eval mode
        model.eval()

        # Threshold the predictive mean at 0.5, then map {0, 1} -> {-1, +1}.
        # NOTE(review): presumably train_y uses -1 / +1 labels, which is why
        # the absolute difference is halved below -- confirm against the
        # data setup.
        test_preds = (
            model(train_x).mean().ge(0.5).float().mul(2).sub(1).squeeze()
        )
        mean_abs_error = torch.mean(torch.abs(train_y - test_preds) / 2)
    finally:
        # Restore the global flag on both the success and the failure path.
        gpytorch.functions.use_toeplitz = True

    assert mean_abs_error.data.squeeze()[0] < 5e-2
kissgp_additive_classification_test.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录