def testModulus(self):
    """Check that ``sl.Modulus`` puts the modulus of each pair in channel 0.

    The input ``x`` has a trailing axis of size 2 (presumably the real and
    imaginary parts of a complex number). The reference value for each pair
    is ``sqrt(a**2 + b**2)``. It is compared against index 0 of the last
    axis of the module's output. Both ``jit=True`` and ``jit=False`` are
    exercised. Requires a CUDA device.
    """
    shape = (100, 10, 4, 2)
    for jit in [True, False]:
        modulus = sl.Modulus(jit=jit)
        x = torch.rand(*shape).cuda()
        y = modulus(x)
        # Use `select` (which drops the indexed axis) on both sides so that
        # reference and output are both (100, 10, 4). The previous version
        # compared a (100, 10, 4) tensor with a narrowed (100, 10, 4, 1)
        # one. Under broadcasting rules those shapes are incompatible, so
        # the subtraction did not compare element-wise.
        real = x.select(3, 0)
        imag = x.select(3, 1)
        expected = torch.sqrt(real * real + imag * imag)
        actual = y.select(3, 0)
        max_err = float((expected - actual).abs().max())
        self.assertLess(max_err, 1e-6)
评论列表
文章目录