test_nn.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def test_bilinear(self):
        """Check that nn.Bilinear agrees with legacy.Bilinear.

        Both modules are built with the same sizes (two 10-feature inputs,
        8 output features). The legacy module is given the new module's
        weight and bias, so the following are expected to match:

        - forward outputs;
        - gradients w.r.t. both inputs;
        - gradients w.r.t. weight and bias.

        Finally, ``F.bilinear`` with the new module's parameters is passed
        to ``_assertGradAndGradgradChecks`` together with the two inputs.
        """
        in1_features, in2_features, out_features = 10, 10, 8
        new_mod = nn.Bilinear(in1_features, in2_features, out_features)
        old_mod = legacy.Bilinear(in1_features, in2_features, out_features)

        # Share parameters so both implementations compute the same function.
        old_mod.weight.copy_(new_mod.weight.data)
        old_mod.bias.copy_(new_mod.bias.data)

        # Batch of 4 samples per input; x1 is drawn before x2.
        x1, x2 = torch.randn(4, 10), torch.randn(4, 10)

        # Forward: new API takes Variables, legacy API takes a list of tensors.
        eval_out = new_mod(Variable(x1), Variable(x2))
        legacy_out = old_mod.forward([x1, x2])
        self.assertEqual(eval_out.data, legacy_out)

        # Fresh leaf Variables so their .grad can be read after backward.
        v1 = Variable(x1, requires_grad=True)
        v2 = Variable(x2, requires_grad=True)

        new_mod.zero_grad()
        old_mod.zeroGradParameters()

        train_out = new_mod(v1, v2)
        upstream = torch.randn(*train_out.size())
        legacy_gin1, legacy_gin2 = old_mod.backward([x1, x2], upstream)
        train_out.backward(upstream)
        new_gin1 = v1.grad.data.clone()
        new_gin2 = v2.grad.data.clone()

        expected_pairs = (
            (new_gin1, legacy_gin1),
            (new_gin2, legacy_gin2),
            (new_mod.weight.grad.data, old_mod.gradWeight),
            (new_mod.bias.grad.data, old_mod.gradBias),
        )
        for actual, reference in expected_pairs:
            self.assertEqual(actual, reference)

        _assertGradAndGradgradChecks(
            self,
            lambda a, b: F.bilinear(a, b, new_mod.weight, new_mod.bias),
            (v1, v2),
        )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号