def test_return_leaf(self):
    """Check a Function whose forward returns one of its inputs unchanged.

    ``Identity.forward`` hands back its input ``a`` (here the leaf ``x``)
    as its first output ``q``.  A hook registered on ``q`` must only see
    the gradient flowing from uses of ``q``, never the gradient from the
    direct use of ``x`` in the same expression.
    """

    class Identity(Function):
        # Legacy (instance-method style) autograd Function.

        def forward(self, a, b):
            # The first output is the input ``a`` itself, not a copy.
            return a, a + b

        def backward(self, grad_a, grad_b):
            # Outputs are (a, a + b): ``a`` receives both output grads,
            # ``b`` only receives the grad of the second output.
            return grad_a + grad_b, grad_b

    # NOTE(review): Inplace is not used in the visible part of this test;
    # presumably a later part (see the ``del`` at the end) exercises it --
    # confirm before removing.
    class Inplace(InplaceFunction):

        def forward(self, a, b):
            # Tell autograd that ``a`` is modified in place by ``add_``.
            self.mark_dirty(a)
            return a.add_(b), b + 2

        def backward(self, grad_a, grad_b):
            return grad_a, grad_a + grad_b

    x = Variable(torch.randn(5, 5), requires_grad=True)
    y = Variable(torch.randn(5, 5), requires_grad=True)

    q, p = Identity()(x, y)

    # Make sure hooks only receive grad from usage of q, not x.
    # The gradients are recorded and checked after backward(): asserting
    # inside the hook would let the test pass silently if the hook never
    # fired.  ``list.append`` returns None, so the hook does not replace
    # the gradient.
    hook_grads = []
    q.register_hook(lambda grad: hook_grads.append(grad.data.clone()))
    (q + p + x).sum().backward()
    self.assertEqual(len(hook_grads), 1)
    self.assertEqual(hook_grads[0], torch.ones(5, 5))

    # x: 2 through Identity.backward (grad_a + grad_b) plus 1 from the
    # direct ``x`` term; y: only grad_b (= 1) reaches it.
    self.assertEqual(x.grad.data, torch.ones(5, 5) * 3)
    self.assertEqual(y.grad.data, torch.ones(5, 5))
    del q, p  # these need to be freed, or next part will raise an error
评论列表
文章目录