test_autograd.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def _function_test(self, cls):
        x = Variable(torch.randn(5, 5), requires_grad=True)
        y = Variable(torch.randn(5, 5), requires_grad=True)
        result = cls.apply(x, 2, y)
        go = Variable(torch.ones(1), requires_grad=True)
        result.sum().backward(go)

        self.assertEqual(x.grad.data, y.data + torch.ones(5, 5))
        self.assertEqual(y.grad.data, x.data + torch.ones(5, 5) * 2)

        self.assertFalse(x.grad.volatile)
        self.assertFalse(y.grad.volatile)
        self.assertIsNotNone(x.grad.grad_fn)
        self.assertIsNotNone(y.grad.grad_fn)

        return x, y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号