def test_simple(self):
    """Trace a small elementwise function and check its ONNX-converted graph.

    The traced function computes ``sigmoid(tanh(x * (x + y)))`` on two
    single-element tensors that require grad. The resulting trace is run
    through the lint pass, the ONNX pass, and the lint pass again. Its
    string form is then handed to ``self.assertExpected``.

    NOTE(review): ``assertExpected`` presumably compares against a recorded
    expected-output file provided by the project's test base class; confirm.
    """
    x = Variable(torch.Tensor([0.4]), requires_grad=True)
    y = Variable(torch.Tensor([0.7]), requires_grad=True)

    def f(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y)))

    # This (legacy) tracing call returns a pair. Only the first element,
    # the trace, is inspected below, so the second element is discarded.
    # NOTE(review): nderivs=0 presumably means no derivative graphs are
    # traced; confirm against the torch version this suite targets.
    trace, _ = torch.jit.trace(f, (x, y), nderivs=0)

    # Lint before and after the ONNX pass. Presumably this checks graph
    # well-formedness, so a malformed rewrite fails here rather than only
    # showing up as a diff in the expected-output comparison.
    torch._C._jit_pass_lint(trace)
    torch._C._jit_pass_onnx(trace)
    torch._C._jit_pass_lint(trace)

    self.assertExpected(str(trace))
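# Scraped web-page residue (not code): "Comment list"
# Scraped web-page residue (not code): "Table of contents"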