def _assertGradAndGradgradChecks(test_case, apply_fn, inputs):
# call assert function rather than returning a bool since it's nicer
# if we get whether this failed on the gradcheck or the gradgradcheck.
test_case.assertTrue(gradcheck(apply_fn, inputs))
dummy_out = apply_fn(*inputs)
def randn_match_cpu_gpu(x):
a = torch.randn(x.size())
if x.is_cuda:
a = a.cuda(x.get_device())
return a
if isinstance(dummy_out, tuple):
grad_y = tuple(Variable(randn_match_cpu_gpu(x), requires_grad=x.requires_grad)
for x in dummy_out if isinstance(x, Variable))
else:
grad_y = (Variable(randn_match_cpu_gpu(dummy_out), requires_grad=dummy_out.requires_grad),)
test_case.assertTrue(gradgradcheck(apply_fn, inputs, grad_y,))
评论列表
文章目录