def test_once_differentiable(self):
    class MyFunction(Function):

        @staticmethod
        def forward(ctx, tensor1, scalar, tensor2):
            ctx.scalar = scalar
            ctx.save_for_backward(tensor1, tensor2)
            return tensor1 + scalar * tensor2 + tensor1 * tensor2

        @staticmethod
        @once_differentiable
        def backward(ctx, grad_output):
            # once_differentiable runs backward with grad mode disabled,
            # so everything here is a plain tensor, not a graph node.
            t1, t2 = ctx.saved_tensors
            # NOTE: self is the test case here
            self.assertTrue(torch.is_tensor(t1))
            self.assertTrue(torch.is_tensor(t2))
            self.assertTrue(torch.is_tensor(grad_output))
            return (grad_output + grad_output * t2, None,
                    grad_output * ctx.scalar + grad_output * t1)

    x, y = self._function_test(MyFunction)
    x_grad_desc = graph_desc(x.grad.grad_fn)
    y_grad_desc = graph_desc(y.grad.grad_fn)
    # The backward of a once_differentiable function cannot itself be
    # differentiated, so the double-backward graph contains an Error node.
    self.assertEqual(x_grad_desc,
                     'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
    self.assertEqual(y_grad_desc,
                     'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
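For context, here is a minimal, self-contained sketch of what the Error node in the asserted graph description means in practice. The Square function is a hypothetical example (it is not part of the test suite); it only assumes the public once_differentiable decorator from torch.autograd.function. Differentiating the output once works normally, but requesting a second-order graph attaches an Error node, which raises at the point of double backward.

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class Square(Function):  # hypothetical example, not from the test suite
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output  # runs with grad mode disabled

x = torch.randn(3, requires_grad=True)
y = Square.apply(x).sum()
# First derivative works; create_graph=True attaches an Error node
# to the result instead of a differentiable graph.
g, = torch.autograd.grad(y, x, create_graph=True)
try:
    g.sum().backward()  # second differentiation hits the Error node
except RuntimeError as err:
    print("double backward rejected:", err)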