function_factory_test.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def test_backward_inv_mv():
    a = torch.Tensor([
        [5, -3, 0],
        [-3, 5, 0],
        [0, 0, 2],
    ])
    b = torch.ones(3, 3).fill_(2)
    c = torch.randn(3)
    actual_a_grad = -torch.ger(a.inverse().mul_(0.5).mv(torch.ones(3)), a.inverse().mul_(0.5).mv(c)) * 2 * 2
    actual_c_grad = (a.inverse() / 2).t().mv(torch.ones(3)) * 2

    a_var = Variable(a, requires_grad=True)
    c_var = Variable(c, requires_grad=True)
    out_var = a_var.mul(Variable(b))
    out_var = gpytorch.inv_matmul(out_var, c_var)
    out_var = out_var.sum() * 2
    out_var.backward()
    a_res = a_var.grad.data
    c_res = c_var.grad.data

    assert(torch.norm(actual_a_grad - a_res) < 1e-4)
    assert(torch.norm(actual_c_grad - c_res) < 1e-4)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号