def test_forward_inv_mm():
    """Compare ``gpytorch.inv_matmul`` against an explicit inverse-then-multiply.

    For right-hand sides with 2, 3 and 4 columns, a fixed symmetric 3x3
    matrix is paired with a random ``3 x n`` matrix.  The reference result
    is ``matrix.inverse().mm(rhs)``; the value returned by
    ``gpytorch.inv_matmul`` must differ from it by a norm below ``1e-4``.
    """
    for num_rhs_cols in (2, 3, 4):
        # torch.Tensor (rather than torch.tensor) is kept deliberately so the
        # integer literals end up in a floating-point tensor.
        matrix = torch.Tensor([
            [5, -3, 0],
            [-3, 5, 0],
            [0, 0, 2],
        ])
        rhs = torch.randn(3, num_rhs_cols)

        # Reference solution computed from the explicit inverse.
        expected = matrix.inverse().mm(rhs)

        # Solution from the function under test, unwrapped back to a tensor.
        solved = gpytorch.inv_matmul(Variable(matrix), Variable(rhs)).data

        assert torch.norm(expected - solved) < 1e-4
# Comment list
# Table of contents