def test_function(self):
    # Defines a custom autograd Function, hands it to the shared
    # `_function_test` helper, and then checks the textual description of the
    # graph attached to the gradient of each of the two values it returns.
    class MyFunction(Function):

        @staticmethod
        def forward(ctx, tensor1, scalar, tensor2):
            # Keep everything backward() reads: the plain scalar as an
            # attribute on ctx, the two tensors through save_for_backward.
            ctx.scalar = scalar
            ctx.save_for_backward(tensor1, tensor2)
            scaled = scalar * tensor2
            product = tensor1 * tensor2
            return tensor1 + scaled + product

        @staticmethod
        def backward(ctx, grad_output):
            # NOTE(review): newer PyTorch releases spell this
            # `ctx.saved_tensors` -- confirm the targeted version before
            # changing it.
            saved1, saved2 = ctx.saved_variables
            # NOTE: `self` is the enclosing test case (captured by closure),
            # not the Function.
            for value in (saved1, saved2, grad_output):
                self.assertIsInstance(value, Variable)
            # The operation order below determines the graph descriptions
            # asserted at the end of the test; keep it as is.
            grad_tensor1 = grad_output + grad_output * saved2
            grad_tensor2 = grad_output * ctx.scalar + grad_output * saved1
            # The middle slot corresponds to `scalar`, for which no gradient
            # is returned.
            return grad_tensor1, None, grad_tensor2

    first, second = self._function_test(MyFunction)

    # Describe both gradient graphs before asserting on either of them.
    first_desc = graph_desc(first.grad.grad_fn)
    second_desc = graph_desc(second.grad.grad_fn)

    expected_first = (
        'Identity(AddBackward(ExpandBackward(AccumulateGrad()), '
        'MulBackward(ExpandBackward(AccumulateGrad()), AccumulateGrad())))'
    )
    expected_second = (
        'Identity(AddBackward(MulConstantBackward(ExpandBackward(AccumulateGrad())), '
        'MulBackward(ExpandBackward(AccumulateGrad()), AccumulateGrad())))'
    )
    self.assertEqual(first_desc, expected_first)
    self.assertEqual(second_desc, expected_second)
# (Web-page residue, not part of the test: "Comment list" / "Table of contents".)