test_scattering.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:pyscatwave 作者: edouardoyallon 项目源码 文件源码
def testModulus(self):
        for jit in [True, False]:
            modulus = sl.Modulus(jit=jit)
            x = torch.cuda.FloatTensor(100,10,4,2).copy_(torch.rand(100,10,4,2))
            y = modulus(x)
            u = torch.squeeze(torch.sqrt(torch.sum(x * x, 3)))
            v = y.narrow(3, 0, 1)

            self.assertLess((u - v).abs().max(), 1e-6)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号