def test_disabled_traced_function(self):
    """Check that a function decorated with ``torch.jit.compile(enabled=False)``
    still produces correct results.

    The decorated function is invoked twice with the same inputs. Both
    results must equal the value obtained by evaluating the same expression
    directly, and must also equal each other.
    """
    a = Variable(torch.Tensor([0.4]), requires_grad=True)
    b = Variable(torch.Tensor([0.7]), requires_grad=True)

    # The decorator is applied with enabled=False, i.e. the JIT is switched
    # off for this function; the call results should match plain evaluation.
    @torch.jit.compile(enabled=False)
    def sigmoid_tanh(lhs, rhs):
        inner = lhs * (lhs + rhs)
        return torch.sigmoid(torch.tanh(inner))

    first = sigmoid_tanh(a, b)
    second = sigmoid_tanh(a, b)

    # Reference value computed with the same expression, outside the decorator.
    expected = torch.sigmoid(torch.tanh(a * (a + b)))
    self.assertEqual(first, expected)
    # Repeated calls with identical inputs must agree.
    self.assertEqual(first, second)
评论列表
文章目录