def test_function_returns_input(self):
class MyFunction(Function):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
def backward(ctx, grad):
return grad * 2
v = Variable(torch.ones(1), requires_grad=True)
MyFunction.apply(v).backward()
self.assertEqual(v.grad.data.tolist(), [2])
v.grad.data.zero_()
MyFunction.apply(v.clone()).backward()
self.assertEqual(v.grad.data.tolist(), [2])