test_modules.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:fast-wavenet.pytorch 作者: dhpollack 项目源码 文件源码
def test_net_forward(self):
        """Check Net's channel wiring and forward pass around feature splits.

        The test:
          1. builds a ``Net`` and asserts that the channel counts of
             ``conv1``..``conv4`` line up with each other,
          2. runs a forward pass on a random ``(1, 1, 4)`` input with values
             in ``[-1, 1)`` and asserts the output size is ``(1, 2, 4)``,
          3. calls ``split_feature`` on ``conv1`` and ``conv2`` and asserts
             the same channel wiring still holds,
          4. runs the same input through the modified model and prints the
             sum of squared differences between the two outputs.
        """

        def assert_channel_wiring(net):
            # conv1 and conv2 must produce the same number of channels, that
            # number must match conv3's input, and conv3 must feed conv4.
            self.assertEqual(net.conv1.out_channels, net.conv2.out_channels)
            self.assertEqual(net.conv1.out_channels, net.conv3.in_channels)
            self.assertEqual(net.conv2.out_channels, net.conv3.in_channels)
            self.assertEqual(net.conv3.out_channels, net.conv4.in_channels)

        model = Net()
        print(model)
        assert_channel_wiring(model)

        # simple forward pass
        # (named `sample` rather than `input` to avoid shadowing the builtin)
        sample = Variable(torch.rand(1, 1, 4) * 2 - 1)
        output = model(sample)
        self.assertEqual(output.size(), (1, 2, 4))

        # feature split
        model.conv1.split_feature(feature_i=1)
        model.conv2.split_feature(feature_i=3)
        print(model)
        assert_channel_wiring(model)

        output2 = model(sample)

        flat_diff = (output - output2).view(-1)
        dot = torch.dot(flat_diff, flat_diff)

        # `dot.data[0]` fails on PyTorch versions where torch.dot returns a
        # 0-dim tensor (indexing a 0-dim tensor raises IndexError). Flattening
        # first yields a one-element view on every version, so the scalar can
        # be extracted portably.
        squared_error = float(dot.data.view(-1)[0])

        # should be close to 0
        # TODO(review): this closeness assertion was disabled in the original
        # test; confirm whether it should be enforced.
        # self.assertTrue(np.isclose(squared_error, 0., atol=1e-2))

        # NOTE: despite the "mse" label, this is the sum (not the mean) of the
        # squared differences; the label is kept unchanged.
        print("mse: ", squared_error)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号