def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
                          f_args_variable, f_args_tensor):
    """Run the standard functional-form autograd checks for one test entry.

    Applies *apply_fn* to the Variable arguments, optionally compares the
    result against the plain-tensor invocation, runs grad/gradgrad checks
    when requested, and finally backpropagates a random gradient to verify
    that the first input's ``.grad`` matches it in type and size.

    Args:
        test_case: the unittest.TestCase providing assert* methods.
        test_name: full test identifier, forwarded to the grad checkers.
        name: base op name, used to decide whether a tensor-method
            comparison applies.
        apply_fn: callable under test, invoked as ``apply_fn(*args)``.
        run_grad_checks: when truthy, also run grad and gradgrad checks.
        f_args_variable: argument tuple of Variables; element 0 is the
            input whose gradient is inspected.
        f_args_tensor: the same arguments as plain tensors.
    """
    var_out = apply_fn(*f_args_variable)

    # Compare against the raw-tensor execution unless this op is excluded
    # from the tensor-method comparison.
    if not exclude_tensor_method(name, test_name):
        tensor_out = apply_fn(*f_args_tensor)
        # Scalar (non-tensor, non-tuple) results are wrapped so that
        # assertEqual compares tensor-to-tensor.
        if not (torch.is_tensor(tensor_out) or isinstance(tensor_out, tuple)):
            tensor_out = torch.DoubleTensor((tensor_out,))
        test_case.assertEqual(unpack_variables(var_out), tensor_out)

    if run_grad_checks:
        run_grad_and_gradgrad_checks(test_case, test_name, apply_fn,
                                     var_out, f_args_variable)

    first_input = f_args_variable[0]
    # Only single-Variable outputs are backpropagated here; tuple outputs
    # (and a missing first input) are skipped.
    if isinstance(var_out, torch.autograd.Variable) and first_input is not None:
        grad_seed = torch.randn(*var_out.size()).type_as(var_out.data)
        var_out.backward(grad_seed)
        # The accumulated gradient must match the input's tensor type and shape.
        test_case.assertTrue(type(first_input.data) == type(first_input.grad.data))
        test_case.assertTrue(first_input.size() == first_input.grad.size())
# (removed web-scrape residue: "评论列表" / "文章目录" — blog navigation text, not part of the source)