test_nn.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def _assertGradAndGradgradChecks(test_case, apply_fn, inputs):
    """Assert that ``apply_fn`` passes both ``gradcheck`` and ``gradgradcheck``.

    The two checks are asserted separately (rather than combined into one
    returned bool) so a failure report shows which of the two failed.

    Args:
        test_case: object providing ``assertTrue`` (e.g. a ``unittest.TestCase``).
        apply_fn: callable invoked as ``apply_fn(*inputs)``; it returns either a
            single Variable or a tuple whose Variable entries are its outputs.
        inputs: tuple of arguments passed to ``apply_fn``.
    """
    test_case.assertTrue(gradcheck(apply_fn, inputs))
    # Run once to learn the shapes/types of the outputs for the grad outputs.
    dummy_out = apply_fn(*inputs)

    def randn_like_output(x):
        # Random tensor with the size of ``x``, placed on the same GPU as ``x``.
        a = torch.randn(x.size())
        if x.requires_grad:
            # gradgradcheck numerically perturbs grad outputs that require
            # grad, so they must have the output's precision: a default-type
            # tensor (normally float32) is too coarse for its eps/tolerances
            # and mismatches double-precision outputs. Outputs that do not
            # require grad keep the default type, as before.
            a = a.type_as(x.data)
        if x.is_cuda:
            a = a.cuda(x.get_device())
        return a

    if isinstance(dummy_out, tuple):
        # Only Variable entries of the output tuple receive a grad output.
        grad_y = tuple(Variable(randn_like_output(x), requires_grad=x.requires_grad)
                       for x in dummy_out if isinstance(x, Variable))
    else:
        grad_y = (Variable(randn_like_output(dummy_out), requires_grad=dummy_out.requires_grad),)

    test_case.assertTrue(gradgradcheck(apply_fn, inputs, grad_y))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号