test_autograd.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def test_reentrant(self):
        y_data = torch.randn(2, 2)

        class Reenter(Function):
            @staticmethod
            def forward(ctx, x_data):
                ctx.x = Variable(x_data, requires_grad=True)
                ctx.y = Variable(y_data, requires_grad=True)
                ctx.output_var = ctx.x * ctx.y
                return ctx.output_var.data

            @staticmethod
            def backward(ctx, grad_output):
                ctx.output_var.sum().backward()
                return ctx.x.grad * grad_output

        x = Variable(torch.randn(2, 2), requires_grad=True)
        out = Reenter.apply(x)
        out.sum().backward(create_graph=True)
        self.assertEqual(x.grad.data, y_data)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号