def test_bilinear(self):
    """Check nn.Bilinear against legacy.Bilinear.

    Both modules are given the same weight and bias, then compared on:
    the forward output, the gradients w.r.t. both inputs, and the gradients
    w.r.t. weight and bias. Finally F.bilinear is run through gradcheck.
    """
    module = nn.Bilinear(10, 10, 8)
    module_legacy = legacy.Bilinear(10, 10, 8)
    # Share parameters so the two implementations are directly comparable.
    module_legacy.weight.copy_(module.weight.data)
    module_legacy.bias.copy_(module.bias.data)

    input1 = torch.randn(4, 10)
    input2 = torch.randn(4, 10)

    # Forward: new-style module takes Variables, legacy module takes a list
    # of raw tensors.
    output = module(Variable(input1), Variable(input2))
    output_legacy = module_legacy.forward([input1, input2])
    self.assertEqual(output.data, output_legacy)

    input1_1 = Variable(input1, requires_grad=True)
    input2_1 = Variable(input2, requires_grad=True)

    # Clear parameter gradients on both sides so the accumulated values
    # compared below come from this single backward pass only.
    module.zero_grad()
    module_legacy.zeroGradParameters()

    output = module(input1_1, input2_1)
    grad_output = torch.randn(*output.size())
    # NOTE(review): assumes legacy backward returns its gradInput as a
    # two-element list [grad wrt input1, grad wrt input2].
    gi1_legacy, gi2_legacy = module_legacy.backward([input1, input2], grad_output)
    output.backward(grad_output)
    gi1 = input1_1.grad.data.clone()
    gi2 = input2_1.grad.data.clone()

    # Input gradients must agree with the legacy implementation.
    self.assertEqual(gi1, gi1_legacy)
    self.assertEqual(gi2, gi2_legacy)
    # Parameter gradients must agree as well.
    self.assertEqual(module.weight.grad.data, module_legacy.gradWeight)
    self.assertEqual(module.bias.grad.data, module_legacy.gradBias)

    # Numerical vs. analytical gradient check of the functional form.
    self.assertTrue(gradcheck(lambda x1, x2: F.bilinear(x1, x2, module.weight, module.bias),
                              (input1_1, input2_1)))
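# 评论列表 (comment list) — leftover navigation label from a scraped web page
# 文章目录 (table of contents) — leftover navigation label from a scraped web page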