def test_legacy_function_none_grad(self):
    """A legacy Function whose backward returns None must leave zero grads.

    ``NoneGradFunction`` is an old-style autograd Function: ``forward`` and
    ``backward`` are instance methods, and the Function is instantiated
    before being applied. Its ``forward`` ignores its input and emits a
    fresh ``(2, 2, 2)`` zero tensor. Its ``backward`` returns ``None``
    instead of a gradient.

    The input fed to it is derived from a single element of ``leaf`` via
    ``expand``/``t``/``sum``. This exercises how a missing gradient travels
    back through those view ops to the leaf. The test expects ``leaf.grad``
    to end up as an all-zero tensor with the leaf's shape.
    """
    class NoneGradFunction(Function):
        def forward(self, x):
            # The output ignores ``x`` entirely.
            return torch.zeros(2, 2, 2)

        def backward(self, grad_output):
            # Deliberately supply no gradient for the input.
            return None

    size = (2, 3)
    leaf = Variable(torch.ones(size), requires_grad=True)

    # One element of the leaf, expanded to (3, 5), transposed, then summed.
    element = leaf[0, 0]
    widened = element.expand(3, 5)
    reduced = widened.t().sum()

    output = NoneGradFunction()(reduced)
    output.sum().backward()

    expected = torch.zeros(size)
    self.assertEqual(leaf.grad.data, expected)
评论列表
文章目录