test_autograd.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def test_function_returns_input(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, grad):
                return grad * 2

        v = Variable(torch.ones(1), requires_grad=True)
        MyFunction.apply(v).backward()
        self.assertEqual(v.grad.data.tolist(), [2])

        v.grad.data.zero_()
        MyFunction.apply(v.clone()).backward()
        self.assertEqual(v.grad.data.tolist(), [2])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号