test_autograd.py source code

python

Project: pytorch  Author: ezyang
def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
                          f_args_variable, f_args_tensor):
    output_variable = apply_fn(*f_args_variable)

    # Check 1: applying the function to raw Tensors must give the same result
    # as applying it to Variables (unless the op is Variable-only).
    if not exclude_tensor_method(name, test_name):
        output_tensor = apply_fn(*f_args_tensor)
        if not torch.is_tensor(output_tensor) and not isinstance(output_tensor, tuple):
            # Scalar result: wrap it in a tensor so assertEqual can compare.
            output_tensor = torch.DoubleTensor((output_tensor,))
        test_case.assertEqual(unpack_variables(output_variable), output_tensor)

    # Check 2: numerical gradient and gradient-of-gradient checks.
    if run_grad_checks:
        run_grad_and_gradgrad_checks(test_case, test_name, apply_fn,
                                     output_variable, f_args_variable)

    # Check 3: backward through a random gradient populates .grad on the
    # first input with matching tensor type and size.
    self_variable = f_args_variable[0]
    if isinstance(output_variable, torch.autograd.Variable) and self_variable is not None:
        output_variable.backward(torch.randn(*output_variable.size()).type_as(output_variable.data))
        test_case.assertTrue(type(self_variable.data) == type(self_variable.grad.data))
        test_case.assertTrue(self_variable.size() == self_variable.grad.size())
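For context, a minimal usage sketch, assuming the legacy PyTorch 0.x Variable API and the surrounding test_autograd.py helpers (exclude_tensor_method, unpack_variables, run_grad_and_gradgrad_checks) are in scope; the operation apply_fn and the test names here are hypothetical, not from the source:

import unittest
import torch
from torch.autograd import Variable

class TestMulAdd(unittest.TestCase):
    def test_mul_add(self):
        def apply_fn(x, y):
            # Hypothetical differentiable operation under test.
            return x * y + x

        # Raw double tensors, plus Variable-wrapped copies with gradients enabled.
        f_args_tensor = (torch.randn(3, 3).double(), torch.randn(3, 3).double())
        f_args_variable = tuple(Variable(t.clone(), requires_grad=True)
                                for t in f_args_tensor)

        # Runs the value check, grad/gradgrad checks, and the backward smoke test.
        run_functional_checks(self, 'test_mul_add', 'mul_add', apply_fn,
                              True, f_args_variable, f_args_tensor)

Passing the same arguments twice, once as Tensors and once as Variables, is what lets the helper confirm the two code paths agree before exercising autograd.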