test_nn.py 文件源码

python
阅读 49 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def test_Conv2d_groups_nobias(self):
        """Check that a bias-free grouped Conv2d equals two independent convs.

        A 4->4 channel, 3x3 convolution with ``groups=2`` is compared against
        two plain 2->2 convolutions. Their weights are copied from the first
        and second halves of the grouped module's weight. Each plain conv sees
        only its half of the input channels and of the upstream gradient.

        The grouped module's results must be reproduced by concatenating:
        - the plain convs' outputs and input gradients along dim 1;
        - their weight gradients along dim 0.
        """
        grouped = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False)
        full_in = Variable(torch.randn(2, 4, 6, 6), requires_grad=True)
        full_out = grouped(full_in)
        upstream = torch.randn(2, 4, 4, 4)
        full_out.backward(upstream)

        # Channel ranges handled by each group: [0, 2) and [2, 4).
        group_slices = (slice(None, 2), slice(2, None))

        part_outs = []
        part_in_grads = []
        part_weight_grads = []
        for chans in group_slices:
            # Standalone conv for one group, seeded from that group's slice of
            # the grouped weight (dim 0 of a Conv2d weight is output channels).
            single = nn.Conv2d(2, 2, kernel_size=3, bias=False)
            single.weight.data.copy_(grouped.weight.data[chans])
            part_in = Variable(full_in.data[:, chans].contiguous(),
                               requires_grad=True)
            part_out = single(part_in)
            part_out.backward(upstream[:, chans].contiguous())
            part_outs.append(part_out)
            part_in_grads.append(part_in.grad.data)
            part_weight_grads.append(single.weight.grad.data)

        self.assertEqual(full_out, torch.cat(part_outs, 1))
        self.assertEqual(full_in.grad.data, torch.cat(part_in_grads, 1))
        self.assertEqual(grouped.weight.grad.data,
                         torch.cat(part_weight_grads, 0))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号