def test_baddbmm(self):
num_batches = 10
M, N, O = 12, 8, 5
b1 = torch.randn(num_batches, M, N)
b2 = torch.randn(num_batches, N, O)
res = torch.bmm(b1, b2)
res2 = torch.Tensor().resize_as_(res).zero_()
res2.baddbmm_(b1,b2)
self.assertEqual(res2, res)
res2.baddbmm_(1,b1,b2)
self.assertEqual(res2, res*2)
res2.baddbmm_(1,.5,b1,b2)
self.assertEqual(res2, res*2.5)
res3 = torch.baddbmm(1,res2,0,b1,b2)
self.assertEqual(res3, res2)
res4 = torch.baddbmm(1,res2,.5,b1,b2)
self.assertEqual(res4, res*3)
res5 = torch.baddbmm(0,res2,1,b1,b2)
self.assertEqual(res5, res)
res6 = torch.baddbmm(.1,res2,.5,b1,b2)
self.assertEqual(res6, res2 * .1 + res * .5)
评论列表
文章目录