def test_potrf(self):
    """Check first- and second-order gradients through ``torch.potrf``.

    A random lower-triangular ``S x S`` matrix ``root`` is multiplied by its
    own transpose, and the symmetric product is passed to ``torch.potrf``.
    ``gradcheck`` and ``gradgradcheck`` are run on that composite function
    with respect to ``root``, once with ``upper=True`` and once with
    ``upper=False``.
    """
    root = Variable(torch.tril(torch.rand(S, S)), requires_grad=True)

    def run_test(upper):
        def func(root):
            # Differentiate with respect to the triangular root rather than
            # the product, so every perturbed input still yields a symmetric
            # matrix of the form root * root^T.
            x = torch.mm(root, root.t())
            return torch.potrf(x, upper)

        # gradcheck/gradgradcheck return a boolean. Depending on the torch
        # version a mismatch may be reported as a False return value instead
        # of an exception, so the result has to be asserted or a wrong
        # gradient would pass silently.
        assert gradcheck(func, [root]), \
            "gradcheck failed for potrf (upper=%s)" % upper
        assert gradgradcheck(func, [root]), \
            "gradgradcheck failed for potrf (upper=%s)" % upper

    run_test(upper=True)
    run_test(upper=False)
评论列表
文章目录