function_factory_test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def test_normal_trace_log_det_quad_form_forward():
    covar = torch.Tensor([
        [5, -3, 0],
        [-3, 5, 0],
        [0, 0, 2],
    ])
    mu_diffs = torch.Tensor([0, -1, 1])
    chol_covar = torch.Tensor([
        [1, -2, 0],
        [0, 1, -2],
        [0, 0, 1],
    ])

    actual = mu_diffs.dot(covar.inverse().matmul(mu_diffs))
    actual += math.log(np.linalg.det(covar.numpy()))
    actual += (covar.inverse().matmul(chol_covar.t().matmul(chol_covar))).trace()

    covarvar = Variable(covar)
    chol_covarvar = Variable(chol_covar)
    mu_diffsvar = Variable(mu_diffs)

    res = gpytorch.trace_logdet_quad_form(mu_diffsvar, chol_covarvar, covarvar)
    assert(all(torch.abs(actual - res.data).div(res.data) < 0.1))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号