def test_save_none_for_backward(self):
test_case = self
class MyFn(Function):
def forward(self, input):
self.save_for_backward(None, input, None)
return input * input
def backward(self, grad_output):
n1, input, n2 = self.saved_tensors
test_case.assertIsNone(n1)
test_case.assertIsNone(n2)
return 2 * input * grad_output
x = Variable(torch.randn(5, 5), requires_grad=True)
y = MyFn()(x)
y.sum().backward()
self.assertEqual(x.grad.data, 2 * x.data)
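# The test above uses the legacy instance-style Function API. As a hedged
# sketch (not part of the test suite; the class name SquareWithNones is made
# up), the same "None is allowed in save_for_backward" pattern with the
# staticmethod/ctx API, assuming a PyTorch version that supports it:
import torch

class SquareWithNones(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # None entries are permitted and come back as None from saved_tensors.
        ctx.save_for_backward(None, input, None)
        return input * input

    @staticmethod
    def backward(ctx, grad_output):
        n1, input, n2 = ctx.saved_tensors
        assert n1 is None and n2 is None
        return 2 * input * grad_output

x = torch.randn(5, 5, requires_grad=True)
SquareWithNones.apply(x).sum().backward()
assert torch.allclose(x.grad, 2 * x)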
def test_mark_non_differentiable(self):
class MyFunction(Function):
@staticmethod
def forward(ctx, input):
output = input > 0
ctx.mark_non_differentiable(output)
return output
@staticmethod
def backward(ctx, grad_output):
return (grad_output * 0).type(torch.DoubleTensor)
x = Variable(torch.randn(5, 5), requires_grad=True)
mask = MyFunction.apply(x)
self.assertFalse(mask.requires_grad)
y = x.masked_fill(mask, 0)
y.sum().backward()
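# Hedged sketch of mark_non_differentiable outside the test harness (the class
# name PositiveMask is made up; assumes the staticmethod/ctx Function API):
# outputs marked non-differentiable never require grad, so autograd does not
# route gradients back through them.
import torch

class PositiveMask(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        mask = input > 0
        ctx.mark_non_differentiable(mask)
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        # Not reached through `mask`: the only output is non-differentiable.
        return None

x = torch.randn(5, 5, requires_grad=True)
mask = PositiveMask.apply(x)
assert not mask.requires_grad
x.masked_fill(mask, 0).sum().backward()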
def test_assign_traces(self):
"""Check that output Variables are assigned traces before they are saved."""
@traceable
class MyFn(Function):
@staticmethod
def forward(ctx, a):
out = a * 2
ctx.save_for_backward(out)
return out
@staticmethod
def backward(ctx, grad_a):
a, = ctx.saved_variables
return a * grad_a
x = Variable(torch.randn(10, 10), requires_grad=True)
trace, out = torch.jit.trace(MyFn.apply, x, nderivs=1)
out.sum().backward()
torch._C._jit_pass_dce(trace)
self.assertExpected(str(trace))
def test_backward_device(self):
# check that current device matches the variable's device
device = [None]
class Identity(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x.clone()
@staticmethod
def backward(ctx, grad_output):
device[0] = torch.cuda.current_device()
return grad_output.clone()
v = Variable(torch.randn(1).cuda(1), requires_grad=True)
Identity.apply(v).backward()
self.assertEqual(device[0], 1)
def test_reentrant(self):
y_data = torch.randn(2, 2)
class Reenter(Function):
@staticmethod
def forward(ctx, x_data):
ctx.x = Variable(x_data, requires_grad=True)
ctx.y = Variable(y_data, requires_grad=True)
ctx.output_var = ctx.x * ctx.y
return ctx.output_var.data
@staticmethod
def backward(ctx, grad_output):
ctx.output_var.sum().backward()
return ctx.x.grad * grad_output
x = Variable(torch.randn(2, 2), requires_grad=True)
out = Reenter.apply(x)
out.sum().backward()
self.assertEqual(x.grad.data, y_data)
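# Hedged sketch of the reentrant pattern the test above exercises: building a
# private graph in forward and re-entering the autograd engine from backward.
# Not from the test suite; Reenter2 is a made-up name, and torch.enable_grad()
# is used defensively in case forward runs with gradients disabled.
import torch

class Reenter2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_data, y_data):
        with torch.enable_grad():
            ctx.x = x_data.detach().requires_grad_()
            ctx.y = y_data.detach().requires_grad_()
            ctx.out = ctx.x * ctx.y
        # Return a detached result so the inner graph stays private to this Function.
        return ctx.out.detach()

    @staticmethod
    def backward(ctx, grad_output):
        # Re-entrant call into the engine: backprop through the inner graph.
        ctx.out.sum().backward()
        return ctx.x.grad * grad_output, ctx.y.grad * grad_output

x = torch.randn(2, 2, requires_grad=True)
y = torch.randn(2, 2)
Reenter2.apply(x, y).sum().backward()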
def test_symbolic_mismatch(self):
class MyFun(Function):
@staticmethod
def symbolic(g, x):
# The inside of this function should never be invoked, because
# we will fail due to an argument mismatch first.
assert False
@staticmethod
def forward(ctx, x, y):
return x + y
x = Variable(torch.randn(2, 2).fill_(1.0))
y = Variable(torch.randn(2, 2).fill_(1.0))
with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
export_to_string(FuncModule(MyFun().apply), (x, y))
# TODO: Do an nn style test for these
def test_assign_traces(self):
"""Check that output Variables are assigned traces before they are saved."""
@traceable
class MyFn(Function):
@staticmethod
def forward(ctx, a):
out = a * 2
ctx.save_for_backward(out)
return out
@staticmethod
def backward(ctx, grad_a):
a, = ctx.saved_variables
return a * grad_a
x = Variable(torch.randn(10, 10), requires_grad=True)
trace, out = torch.jit.trace(MyFn.apply, x, nderivs=1)
out.sum().backward()
torch._C._jit_pass_dce(trace)
self.assertExpectedTrace(trace)
def test_inplace_check(self):
class MyInplaceFn(Function):
@staticmethod
def forward(self, x):
x.add_(1)
self.mark_dirty(x)
return x
@staticmethod
def backward(self, grad):
return grad
@torch.jit.compile(nderivs=0)
def fn(x):
return MyInplaceFn.apply(x)
x = Variable(torch.randn(5, 5))
fn(x) # trace
with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
fn(x)
def test_function_returns_input(self):
class MyFunction(Function):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
def backward(ctx, grad):
return grad * 2
v = Variable(torch.ones(1), requires_grad=True)
MyFunction.apply(v).backward()
self.assertEqual(v.grad.data.tolist(), [2])
v.grad.data.zero_()
MyFunction.apply(v.clone()).backward()
self.assertEqual(v.grad.data.tolist(), [2])
def test_mark_non_differentiable_mixed(self):
class MyFunction(Function):
@staticmethod
def forward(ctx, input):
a = input + 1
b = input + 2
ctx.mark_non_differentiable(a)
return a, b
@staticmethod
def backward(ctx, grad_a, grad_b):
self.assertTrue((grad_a == 0).all())
self.assertTrue((grad_b == 1).all())
return grad_b
x = Variable(torch.randn(5, 5), requires_grad=True)
a, b = MyFunction.apply(x)
self.assertFalse(a.requires_grad)
self.assertTrue(b.requires_grad)
b.sum().backward()
self.assertEqual(x.grad.data, torch.ones(5, 5))
def test_mark_non_differentiable_none(self):
# This used to segfault because MyFunction would send back null
# gradients to MulBackward, which is implemented in C++. C++
# implemented functions expect incoming grad_outputs to be non-null.
class MyFunction(Function):
@staticmethod
def forward(ctx, input):
output = input.clone()
ctx.mark_non_differentiable(output)
return output
@staticmethod
def backward(ctx, grad_output):
return None
x = Variable(torch.randn(5, 5), requires_grad=True)
r = MyFunction.apply(x * x)
(r * x).sum().backward()
def test_reentrant(self):
y_data = torch.randn(2, 2)
class Reenter(Function):
@staticmethod
def forward(ctx, x_data):
ctx.x = Variable(x_data, requires_grad=True)
ctx.y = Variable(y_data, requires_grad=True)
ctx.output_var = ctx.x * ctx.y
return ctx.output_var.data
@staticmethod
def backward(ctx, grad_output):
ctx.output_var.sum().backward()
return ctx.x.grad * grad_output
x = Variable(torch.randn(2, 2), requires_grad=True)
out = Reenter.apply(x)
out.sum().backward(create_graph=True)
self.assertEqual(x.grad.data, y_data)
def test_inplace_view_python(self):
# in-place modifications of a view created by a Python autograd Function
a = Variable(torch.randn(4, 4), requires_grad=True)
b = Variable(torch.randn(2, 2), requires_grad=True)
class PyAdd(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
ctx.mark_dirty(x)
x.add_(y)
return x
@staticmethod
def backward(ctx, grad):
return grad, grad
def func(root, b):
x = root.clone()
PyAdd.apply(x.narrow(1, 2, 2).narrow(0, 1, 2), b)
PyAdd.apply(x.narrow(1, 0, 2).narrow(0, 1, 2), b)
return x
gradcheck(func, [a, b], raise_exception=True)
go = Variable(torch.randn(a.size()), requires_grad=True)
gradgradcheck(func, (a, b), (go,))
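# gradcheck/gradgradcheck, as used above, compare a Function's analytical
# gradients against numerical ones. A hedged standalone sketch (Cube is a
# made-up name; gradcheck expects double-precision inputs):
import torch
from torch.autograd import gradcheck

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input ** 3

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # d/dx x^3 = 3x^2
        return 3 * input ** 2 * grad_output

x = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
assert gradcheck(Cube.apply, (x,), raise_exception=True)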
def test_function(self):
class MyFunction(Function):
@staticmethod
def forward(ctx, tensor1, scalar, tensor2):
ctx.scalar = scalar
ctx.save_for_backward(tensor1, tensor2)
return tensor1 + scalar * tensor2 + tensor1 * tensor2
@staticmethod
def backward(ctx, grad_output):
var1, var2 = ctx.saved_variables
# NOTE: self is the test case here
self.assertIsInstance(var1, Variable)
self.assertIsInstance(var2, Variable)
self.assertIsInstance(grad_output, Variable)
return (grad_output + grad_output * var2, None,
grad_output * ctx.scalar + grad_output * var1)
x, y = self._function_test(MyFunction)
x_grad_desc = graph_desc(x.grad.grad_fn)
y_grad_desc = graph_desc(y.grad.grad_fn)
self.assertEqual(
x_grad_desc,
'Identity(AddBackward(ExpandBackward(AccumulateGrad()), '
'MulBackward(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')
self.assertEqual(
y_grad_desc,
'Identity(AddBackward(MulConstantBackward(ExpandBackward(AccumulateGrad())), '
'MulBackward(ExpandBackward(AccumulateGrad()), AccumulateGrad())))')
def test_once_differentiable(self):
class MyFunction(Function):
@staticmethod
def forward(ctx, tensor1, scalar, tensor2):
ctx.scalar = scalar
ctx.save_for_backward(tensor1, tensor2)
return tensor1 + scalar * tensor2 + tensor1 * tensor2
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
t1, t2 = ctx.saved_tensors
# NOTE: self is the test case here
self.assertTrue(torch.is_tensor(t1))
self.assertTrue(torch.is_tensor(t2))
self.assertTrue(torch.is_tensor(grad_output))
return (grad_output + grad_output * t2, None,
grad_output * ctx.scalar + grad_output * t1)
x, y = self._function_test(MyFunction)
x_grad_desc = graph_desc(x.grad.grad_fn)
y_grad_desc = graph_desc(y.grad.grad_fn)
self.assertEqual(graph_desc(x.grad.grad_fn),
'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
self.assertEqual(graph_desc(y.grad.grad_fn),
'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
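# Hedged sketch of @once_differentiable in isolation (ScaledExp is a made-up
# name; assumes once_differentiable is importable from torch.autograd.function,
# as in the versions these tests target). The decorated backward runs without
# gradient tracking, so requesting a second-order derivative through it fails.
import torch
from torch.autograd.function import once_differentiable

class ScaledExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, scale):
        ctx.scale = scale
        output = input.exp() * scale
        ctx.save_for_backward(output)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        # d/dx (scale * exp(x)) = scale * exp(x) = output; None for `scale`.
        return grad_output * output, None

x = torch.randn(3, requires_grad=True)
ScaledExp.apply(x, 2.0).sum().backward()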
def test_hook_none(self):
# WARNING: this is a test for autograd internals.
# You should never have to use such things in your code.
class NoneGradientFunction(Function):
def forward(self, x, y):
assert self.needs_input_grad[0]
assert not self.needs_input_grad[1]
return x, y
def backward(self, grad_x, grad_y):
return grad_x, None
fn = NoneGradientFunction()
was_called = [False]
def hook(grad_input, grad_output):
self.assertIsInstance(grad_input, tuple)
self.assertIsInstance(grad_output, tuple)
self.assertIsNotNone(grad_input[0])
self.assertIsNone(grad_input[1])
self.assertIsNotNone(grad_output[0])
self.assertIsNotNone(grad_output[1])
was_called[0] = True
fn.register_hook(hook)
x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5))
sum(fn(x, y)).sum().backward()
self.assertTrue(was_called[0])
def test_gc_in_destructor(self):
"""
Previously, if a Function destructor triggered a garbage collection,
the Variable's tp_dealloc handler would get called twice leading to a
segfault.
"""
class CollectOnDelete(Function):
def __del__(self):
gc.collect()
for i in range(10):
Variable(torch.randn(10, 10), _grad_fn=CollectOnDelete())
def test_too_many_grads(self):
class MyFn(Function):
def forward(self, input):
return input
def backward(self, grad_output):
return grad_output, None, None
x = Variable(torch.randn(5, 5), requires_grad=True)
y = MyFn()(x)
y.sum().backward()
self.assertEqual(x.grad.data, x.data.clone().fill_(1))
def test_dep_nograd(self):
class F1(Function):
def forward(self, input):
out = torch.randn(input.size())
self.mark_non_differentiable(out)
return input, out
def backward(self, grad_output, ignored):
return grad_output
class F2(Function):
def forward(self, input, ignored):
return input
def backward(self, grad_output):
return grad_output, None
x = Variable(torch.randn(5), requires_grad=True)
a, b = F1()(x)
b = b + 1 # separate F1 from F2 by another op
self.assertTrue(a.requires_grad)
self.assertFalse(b.requires_grad)
c = F2()(a, b)
c.backward(torch.ones(c.size()))
self.assertEqual(x.grad.data, torch.ones(x.size()))
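# Hedged staticmethod counterpart of the needs_input_grad / None-gradient
# pattern seen in the legacy-style tests above (MulOnlyGradA is a made-up name;
# assumes ctx.needs_input_grad is available, as in current PyTorch).
import torch

class MulOnlyGradA(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # Return None for inputs that do not need a gradient.
        grad_a = grad_output * b if ctx.needs_input_grad[0] else None
        grad_b = grad_output * a if ctx.needs_input_grad[1] else None
        return grad_a, grad_b

a = torch.randn(5, requires_grad=True)
b = torch.randn(5)  # does not require grad, so grad_b stays None
MulOnlyGradA.apply(a, b).sum().backward()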
def test_return_leaf(self):
class Identity(Function):
def forward(self, a, b):
return a, a + b
def backward(self, grad_a, grad_b):
return grad_a + grad_b, grad_b
class Inplace(InplaceFunction):
def forward(self, a, b):
self.mark_dirty(a)
return a.add_(b), b + 2
def backward(self, grad_a, grad_b):
return grad_a, grad_a + grad_b
x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)
q, p = Identity()(x, y)
# Make sure hooks only receive grad from usage of q, not x.
q.register_hook(
lambda grad: self.assertEqual(grad.data, torch.ones(5, 5)))
(q + p + x).sum().backward()
self.assertEqual(x.grad.data, torch.ones(5, 5) * 3)
self.assertEqual(y.grad.data, torch.ones(5, 5))
del q, p # these need to be freed, or next part will raise an error
def test_legacy_fail(self):
class MyLegacyFn(Function):
def forward(self, x):
return x
def backward(self, grad_output):
return grad_output
x = Variable(torch.Tensor([0]), requires_grad=True)
trace = torch._C._tracer_enter((x,), 0)
self.assertRaisesRegex(RuntimeError, "MyLegacyFn", lambda: MyLegacyFn()(x))
torch._C._tracer_exit((x,))