def test_cse(self):
    """Check the CSE pass on a trace containing repeated subexpressions.

    Two gradient-requiring inputs are traced through an expression that
    deliberately spells out ``(a + b) * (a + b) * (a + b)`` twice and
    ``tanh`` of the same value twice. The trace is then run through the
    lint, ONNX, lint and CSE passes (in that order) and its string form
    is handed to ``self.assertExpected``.
    """
    lhs = Variable(torch.Tensor([0.4, 0.3]), requires_grad=True)
    rhs = Variable(torch.Tensor([0.7, 0.5]), requires_grad=True)

    trace = torch._C._tracer_enter((lhs, rhs), 0)
    # The duplicated subexpressions below are intentional: do not factor
    # them into shared temporaries, or the recorded graph would change.
    cubed = (lhs + rhs) * (lhs + rhs) * (lhs + rhs)
    tanh_sum = torch.tanh(cubed) + torch.tanh(cubed)
    result = (lhs + rhs) * (lhs + rhs) * (lhs + rhs) + tanh_sum
    torch._C._tracer_exit((result,))

    # Pass pipeline: lint -> onnx -> lint -> cse, looked up and applied
    # one at a time in the same order as the explicit calls would be.
    for pass_suffix in ("lint", "onnx", "lint", "cse"):
        run_pass = getattr(torch._C, "_jit_pass_" + pass_suffix)
        run_pass(trace)

    rendered = str(trace)
    self.assertExpected(rendered)
评论列表
文章目录