def test_normal_trace_log_det_quad_form_forward():
    """Check ``gpytorch.trace_logdet_quad_form`` against a dense reference.

    The reference value is computed explicitly as

        mu^T K^{-1} mu + log det(K) + tr(K^{-1} C^T C)

    with ``K = covar``, ``mu = mu_diffs`` and ``C = chol_covar``. The value
    returned by ``gpytorch.trace_logdet_quad_form`` must agree with it to
    within 10% relative error.
    """
    covar = torch.Tensor([
        [5, -3, 0],
        [-3, 5, 0],
        [0, 0, 2],
    ])
    mu_diffs = torch.Tensor([0, -1, 1])
    # Only C^T C enters the reference formula below. Despite its name, this
    # C is an arbitrary upper-triangular matrix, not the Cholesky factor of
    # ``covar``.
    chol_covar = torch.Tensor([
        [1, -2, 0],
        [0, 1, -2],
        [0, 0, 1],
    ])

    # K^{-1} is used by two terms; compute it once.
    covar_inv = covar.inverse()
    expected = mu_diffs.dot(covar_inv.matmul(mu_diffs))
    expected += math.log(np.linalg.det(covar.numpy()))
    expected += covar_inv.matmul(chol_covar.t().matmul(chol_covar)).trace()

    res = gpytorch.trace_logdet_quad_form(
        Variable(mu_diffs), Variable(chol_covar), Variable(covar)
    )

    # Divide by |res| so that a negative result cannot make the check pass
    # vacuously. Reduce with max() instead of iterating with all(), because
    # iterating fails when the comparison is a 0-dim tensor.
    rel_err = torch.abs(expected - res.data) / torch.abs(res.data)
    assert rel_err.max() < 0.1
评论列表
文章目录