def run_test_matmul_ab_ci8_shape(self, shape, k, transpose=False):
    """Compare ``self.linalg.matmul`` on ci8 operands against a NumPy reference.

    Random int8 operands are generated whose last axis holds interleaved
    (real, imag) pairs. A reference product is computed with ``np.matmul`` in
    complex64. The same bytes are then viewed as ``bf.DataType.ci8``, copied to
    CUDA memory, multiplied with ``self.linalg.matmul``, copied back, and
    checked with ``np.testing.assert_allclose`` using ``RTOL``/``ATOL``.

    Args:
        shape: Target shape ``(..., M, N)``. Leading entries are batch dims.
            Any sequence is accepted; it is converted to a tuple.
        k: Inner (contraction) dimension, counted in complex elements.
        transpose: If True, the operands are ``H(B)`` and ``H(A)`` instead of
            ``A`` and ``B``. ``H`` is defined elsewhere in this file
            (presumably the conjugate transpose — confirm). The product shape
            is then ``(..., N, M)``.
    """
    shape = tuple(shape)
    batch, m, n = shape[:-2], shape[-2], shape[-1]
    # Last axis stores interleaved (re, im) int8 pairs, hence the factor of 2.
    ashape_complex = batch + (m, k * 2)
    bshape_complex = batch + (k, n * 2)
    # Draw directly from the full int8 range. Scaling floats to [0, 255) and
    # casting to int8 would depend on out-of-range float->int conversion,
    # whose result NumPy leaves platform-dependent.
    a8 = np.random.randint(-128, 128, size=ashape_complex, dtype=np.int8)
    b8 = np.random.randint(-128, 128, size=bshape_complex, dtype=np.int8)

    # Reference: widen to float32, then reinterpret each (re, im) pair as one
    # complex64 element.
    a_gold = a8.astype(np.float32).view(np.complex64)
    b_gold = b8.astype(np.float32).view(np.complex64)
    if transpose:
        a_gold, b_gold = H(b_gold), H(a_gold)
    c_gold = np.matmul(a_gold, b_gold)

    # Device operands: the same int8 bytes viewed as bifrost ci8, in CUDA memory.
    a = bf.asarray(a8.view(bf.DataType.ci8), space='cuda')
    b = bf.asarray(b8.view(bf.DataType.ci8), space='cuda')
    if transpose:
        a, b = H(b), H(a)
    c = bf.zeros_like(c_gold, space='cuda')
    # NOTE(review): presumably matmul(alpha, a, b, beta, c) computes
    # c = alpha * (a @ b) + beta * c — confirm against bf.linalg.
    self.linalg.matmul(1, a, b, 0, c)
    c = c.copy('system')
    np.testing.assert_allclose(c, c_gold, RTOL, ATOL)
评论列表
文章目录