mul_lazy_variable_test.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def test_trace_log_det_quad_form():
    """Check MulLazyVariable.trace_log_det_quad_form against a dense reference.

    The reference matrix is built by evaluating two Toeplitz factors and a
    Kronecker-product factor to dense form, multiplying them element-wise,
    and adding a diagonal. The lazy variable returned by ``make_mul_lazy_var``
    is then run through ``trace_log_det_quad_form``. The check covers the
    forward value (within an absolute tolerance of 1.5) and the gradients of
    each parameter (within 10% relative error).
    """
    mu_diffs_var = Variable(torch.arange(1, 5, 1))
    chol_covar_1_var = Variable(torch.eye(4))

    # Reference (dense) computation. These leaf variables receive the
    # expected gradients.
    c1_var = Variable(torch.Tensor([5, 1, 2, 0]), requires_grad=True)
    c2_var = Variable(torch.Tensor([[6, 0], [1, -1]]), requires_grad=True)
    c3_var = Variable(torch.Tensor([7, 2, 1, 0]), requires_grad=True)
    diag_var = Variable(torch.Tensor([1]), requires_grad=True)
    diag_var_expand = diag_var.expand(4)
    dense_toeplitz_a = ToeplitzLazyVariable(c1_var).evaluate()
    dense_kronecker = KroneckerProductLazyVariable(c2_var).evaluate()
    dense_toeplitz_b = ToeplitzLazyVariable(c3_var).evaluate()
    # Element-wise product of the three dense factors, plus the diagonal.
    actual = dense_toeplitz_a * dense_kronecker * dense_toeplitz_b + diag_var_expand.diag()

    # Lazy computation under test.
    mul_lv, diag = make_mul_lazy_var()
    t1, t2, t3 = mul_lv.lazy_vars

    # Forward pass: lazy and dense results only need to agree loosely.
    lazy_result = mul_lv.trace_log_det_quad_form(mu_diffs_var, chol_covar_1_var)
    dense_fn = gpytorch._trace_logdet_quad_form_factory_class()
    dense_result = dense_fn(mu_diffs_var, chol_covar_1_var, actual)
    lazy_value = lazy_result.data.squeeze()[0]
    dense_value = dense_result.data.squeeze()[0]
    assert math.fabs(lazy_value - dense_value) < 1.5

    # Backward pass: compare the gradients that reached each set of leaves.
    lazy_result.backward()
    dense_result.backward()

    def relative_error(expected, observed):
        # Norm of the difference, scaled by the norm of the reference gradient.
        return (expected - observed).abs().norm() / expected.abs().norm()

    assert relative_error(c1_var.grad.data, t1.column.grad.data) < 1e-1
    assert relative_error(c2_var.grad.data, t2.columns.grad.data) < 1e-1
    assert relative_error(c3_var.grad.data, t3.column.grad.data) < 1e-1
    assert relative_error(diag_var.grad.data, diag.grad.data) < 1e-1
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号