def test_reuse_function(self):
    """Check that a compiled function can be re-invoked on its own output.

    ``clinear`` wraps ``F.linear`` and is compiled with ``torch.jit.compile``.
    It is called once outside ``assertCompiled``. It is then called again,
    inside ``assertCompiled``, with the first call's output as its input.
    The result must equal the same two-step computation done eagerly with
    ``F.linear``.
    """
    # NOTE(review): nderivs=0 presumably requests a forward-only compile
    # (no derivative traces) -- confirm against this version's
    # torch.jit.compile.
    @torch.jit.compile(nderivs=0)
    def clinear(*args):
        return F.linear(*args)

    x = Variable(torch.randn(1, 1))
    weights = Variable(torch.randn(1, 1))

    # linear AKA addmm without bias is of particular interest
    # because we allocate a zero-filled new variable when we execute,
    # and then *fill* it with the result
    r1_ = clinear(x, weights)
    # assertCompiled is defined elsewhere in the test class; presumably it
    # verifies that this call runs the compiled trace -- TODO confirm.
    with self.assertCompiled(clinear):
        # Feed the compiled function's own output back in as its input.
        r1 = clinear(r1_, weights)
    r2 = F.linear(F.linear(x, weights), weights)
    self.assertEqual(r1, r2)
评论列表
文章目录