def test_reentrant(self):
y_data = torch.randn(2, 2)
class Reenter(Function):
@staticmethod
def forward(ctx, x_data):
ctx.x = Variable(x_data, requires_grad=True)
ctx.y = Variable(y_data, requires_grad=True)
ctx.output_var = ctx.x * ctx.y
return ctx.output_var.data
@staticmethod
def backward(ctx, grad_output):
ctx.output_var.sum().backward()
return ctx.x.grad * grad_output
x = Variable(torch.randn(2, 2), requires_grad=True)
out = Reenter.apply(x)
out.sum().backward(create_graph=True)
self.assertEqual(x.grad.data, y_data)
# Stray web-page labels from the scraped source ("comment list", "table of contents"); not part of the test.