def test_mm(self):
    """Check sparse-by-dense matrix products against their dense equivalents.

    For several shapes, a sparse matrix ``x`` of shape (di, dj) is multiplied
    by a dense ``y`` of shape (dj, dk) using ``torch.addmm`` (with and without
    scaling factors) and ``torch.mm``. Each result must equal the same
    operation applied to ``x.to_dense()``.
    """
    def test_shape(di, dj, dk):
        # NOTE(review): assumes _gen_sparse(d, nnz, size) returns a tuple whose
        # first element is a sparse tensor with ``d`` sparse dims, about ``nnz``
        # non-zeros and the given size -- confirm against the helper.
        x, _, _ = self._gen_sparse(2, 20, [di, dj])
        t = torch.randn(di, dk)
        y = torch.randn(dj, dk)
        alpha = random.random()
        beta = random.random()

        # Scaled form: beta * t + alpha * (x @ y). Keyword arguments are used
        # instead of the deprecated positional overload
        # addmm(beta, input, alpha, mat1, mat2).
        res = torch.addmm(t, x, y, beta=beta, alpha=alpha)
        expected = torch.addmm(t, x.to_dense(), y, beta=beta, alpha=alpha)
        self.assertEqual(res, expected)

        # Unscaled form: t + x @ y.
        res = torch.addmm(t, x, y)
        expected = torch.addmm(t, x.to_dense(), y)
        self.assertEqual(res, expected)

        # Plain product: x @ y.
        res = torch.mm(x, y)
        expected = torch.mm(x.to_dense(), y)
        self.assertEqual(res, expected)

    test_shape(10, 100, 100)
    test_shape(100, 1000, 200)
    test_shape(64, 10000, 300)
评论列表
文章目录