def test_requires_grad(self):
x = Variable(torch.randn(5, 5))
y = Variable(torch.randn(5, 5))
z = Variable(torch.randn(5, 5), requires_grad=True)
a = x + y
self.assertFalse(a.requires_grad)
b = a + z
self.assertTrue(b.requires_grad)
def error():
raise RuntimeError
# Make sure backward isn't called on these
a._backward_hooks = OrderedDict()
x._backward_hooks = OrderedDict()
y._backward_hooks = OrderedDict()
a._backward_hooks['test'] = error
x._backward_hooks['test'] = error
y._backward_hooks['test'] = error
b.backward(torch.ones(5, 5))
# Comment list
# Article table of contents