test_modules.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:fast-wavenet.pytorch 作者: dhpollack 项目源码 文件源码
def test_dilate(self):
        input = Variable(torch.arange(0, 13).view(1, 1, 13))

        dilated, _ = dilate(input, 1)
        self.assertEqual(dilated.size(), (1, 1, 13))
        self.assertEqual(dilated[0, 0, 4].data[0], 4)

        dilated, _ = dilate(input, 2)
        self.assertEqual(dilated.size(), (2, 1, 7))
        self.assertEqual(dilated[1, 0, 2].data[0], 4)

        dilated, _ = dilate(input, 4)
        self.assertEqual(dilated.size(), (4, 1, 4))
        self.assertEqual(dilated[3, 0, 1].data[0], 4)

        dilated, _ = dilate(dilated, 1)
        self.assertEqual(dilated.size(), (1, 1, 16))
        self.assertEqual(dilated[0, 0, 7].data[0], 4)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号