def test_assign_traces(self):
    """Check that output Variables are assigned traces before being saved.

    ``MyFn.forward`` hands its own result to ``ctx.save_for_backward``, so
    the value unpacked in ``backward`` is the forward output rather than
    the input. The function is traced with ``nderivs=1``, a backward pass
    is run on the summed output, the DCE pass is applied to the trace, and
    the trace's string form is passed to ``self.assertExpected``.
    """

    @traceable
    class MyFn(Function):
        @staticmethod
        def forward(ctx, a):
            doubled = a * 2
            # The saved value is the output itself, not the input.
            ctx.save_for_backward(doubled)
            return doubled

        @staticmethod
        def backward(ctx, grad_a):
            (saved_out,) = ctx.saved_variables
            return saved_out * grad_a

    inp = Variable(torch.randn(10, 10), requires_grad=True)
    trace, result = torch.jit.trace(MyFn.apply, inp, nderivs=1)
    result.sum().backward()
    torch._C._jit_pass_dce(trace)
    self.assertExpected(str(trace))
评论列表
文章目录