def test_trace_log_det_quad_form():
    """Compare ``trace_log_det_quad_form`` on a MulLazyVariable to a dense reference.

    A dense reference matrix is assembled by evaluating each lazy factor and
    combining them explicitly.

    The forward value is compared against the lazy variable returned by
    ``make_mul_lazy_var()`` using an absolute tolerance. The gradients with
    respect to every column parameter and the diagonal are compared using a
    relative-error tolerance.
    """

    def relative_error(expected_grad, result_grad):
        # Norm of the difference, scaled by the norm of the reference gradient.
        return (expected_grad - result_grad).norm() / expected_grad.norm()

    mu_diffs_var = Variable(torch.arange(1, 5, 1))
    # Identity Cholesky factor.
    chol_covar_1_var = Variable(torch.eye(4))

    # Dense reference: evaluate each lazy factor and combine them explicitly.
    c1_var = Variable(torch.Tensor([5, 1, 2, 0]), requires_grad=True)
    c2_var = Variable(torch.Tensor([[6, 0], [1, -1]]), requires_grad=True)
    c3_var = Variable(torch.Tensor([7, 2, 1, 0]), requires_grad=True)
    diag_var = Variable(torch.Tensor([1]), requires_grad=True)
    # A single learnable value broadcast to all four diagonal entries.
    diag_var_expand = diag_var.expand(4)
    toeplitz_1 = ToeplitzLazyVariable(c1_var).evaluate()
    kronecker_product = KroneckerProductLazyVariable(c2_var).evaluate()
    toeplitz_2 = ToeplitzLazyVariable(c3_var).evaluate()
    # ``*`` is elementwise, so this is the Hadamard product of the three dense
    # factors plus a constant diagonal.
    dense_matrix = toeplitz_1 * kronecker_product * toeplitz_2 + diag_var_expand.diag()

    # Lazy computation under test.
    # NOTE(review): presumably make_mul_lazy_var() builds its factors from the
    # same column values and diagonal as above. Confirm this, since every
    # comparison below relies on it.
    mul_lv, diag = make_mul_lazy_var()
    t1, t2, t3 = mul_lv.lazy_vars

    # Forward pass.
    tldqf_res = mul_lv.trace_log_det_quad_form(mu_diffs_var, chol_covar_1_var)
    tldqf_expected = gpytorch._trace_logdet_quad_form_factory_class()(
        mu_diffs_var, chol_covar_1_var, dense_matrix
    )
    # NOTE(review): the loose absolute tolerance presumably allows for an
    # approximate/stochastic log-det estimate. Confirm before tightening it.
    assert math.fabs(tldqf_res.data.squeeze()[0] - tldqf_expected.data.squeeze()[0]) < 1.5

    # Backward pass: every parameter's gradient must agree within 10% relative error.
    tldqf_res.backward()
    tldqf_expected.backward()
    grad_rtol = 1e-1
    assert relative_error(c1_var.grad.data, t1.column.grad.data) < grad_rtol
    assert relative_error(c2_var.grad.data, t2.columns.grad.data) < grad_rtol
    assert relative_error(c3_var.grad.data, t3.column.grad.data) < grad_rtol
    assert relative_error(diag_var.grad.data, diag.grad.data) < grad_rtol
评论列表
文章目录