def run_test_matmul_aa_dtype_shape(self, shape, dtype, axes=None, conj=False):
    """Check ``self.linalg.matmul`` on an ``aa @ H(aa)`` product against NumPy.

    A random host array of ``shape`` is drawn in [0, 1), scaled by 127 and
    cast to ``dtype``. It is transposed by ``axes`` and optionally conjugated.
    The NumPy reference ``np.matmul(aa, H(aa))`` then has its strict upper
    triangle (over the last two axes) zeroed. The same array is uploaded to
    the ``'cuda'`` space and gets the same transpose/conjugate treatment. It
    is passed to ``self.linalg.matmul`` with ``None`` as the third argument,
    writing into a zero-initialised device buffer. The result is copied back
    to ``'system'`` and compared with ``np.testing.assert_allclose`` using the
    module-level ``RTOL`` / ``ATOL`` tolerances.

    NOTE(review): ``H`` is a helper defined elsewhere in this file,
    presumably the conjugate transpose of the last two axes -- confirm.
    The argument order ``(1, aa, None, 0, c)`` presumably means
    (alpha, a, b, beta, c), with ``b=None`` requesting the a-times-a^H form
    -- confirm against the linalg API.

    Args:
        shape: Shape of the random input array.
        dtype: Element type the scaled random input is cast to.
        axes: Axis permutation applied via ``transpose`` before the product.
            ``None`` means the identity permutation.
        conj: If truthy, conjugate the transposed input before the product.
    """
    host = (np.random.random(size=shape) * 127).astype(dtype)
    if axes is None:
        axes = range(len(shape))

    def _prepare(arr):
        # Apply the identical permutation (and optional conjugation) to the
        # host copy and to the device copy, so both sides see the same input.
        view = arr.transpose(axes)
        return view.conj() if conj else view

    host_in = _prepare(host)
    expected = np.matmul(host_in, H(host_in))

    # The reference is compared in full against a zero-initialised output
    # buffer, so entries strictly above the diagonal are expected to stay 0.
    # NOTE(review): presumably matmul only fills the lower triangle when b is
    # None -- confirm.
    n_rows = shape[axes[-2]]
    rows, cols = np.triu_indices(n_rows, 1)
    expected[..., rows, cols] = 0

    device_in = _prepare(bf.asarray(host, space='cuda'))
    result = bf.zeros_like(expected, space='cuda')
    self.linalg.matmul(1, device_in, None, 0, result)
    np.testing.assert_allclose(result.copy('system'), expected, RTOL, ATOL)
评论列表
文章目录