def test_addbmm(self):
# num_batches = 10
# M, N, O = 12, 8, 5
num_batches = 2
M, N, O = 2, 3, 4
b1 = torch.randn(num_batches, M, N)
b2 = torch.randn(num_batches, N, O)
res = torch.bmm(b1, b2)
res2 = torch.Tensor().resize_as_(res[0]).zero_()
res2.addbmm_(b1, b2)
self.assertEqual(res2, res.sum(0, False))
res2.addbmm_(1, b1, b2)
self.assertEqual(res2, res.sum(0, False) * 2)
res2.addbmm_(1., .5, b1, b2)
self.assertEqual(res2, res.sum(0, False) * 2.5)
res3 = torch.addbmm(1, res2, 0, b1, b2)
self.assertEqual(res3, res2)
res4 = torch.addbmm(1, res2, .5, b1, b2)
self.assertEqual(res4, res.sum(0, False) * 3)
res5 = torch.addbmm(0, res2, 1, b1, b2)
self.assertEqual(res5, res.sum(0, False))
res6 = torch.addbmm(.1, res2, .5, b1, b2)
self.assertEqual(res6, res2 * .1 + (res.sum(0) * .5))
评论列表
文章目录