test_linalg.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:bifrost 作者: ledatelescope 项目源码 文件源码
def run_test_matmul_aa_dtype_shape(self, shape, dtype, axes=None, conj=False):
        a = ((np.random.random(size=shape)) * 127).astype(dtype)
        if axes is None:
            axes = range(len(shape))
        aa = a.transpose(axes)
        if conj:
            aa = aa.conj()
        c_gold = np.matmul(aa, H(aa))
        triu = np.triu_indices(shape[axes[-2]], 1)
        c_gold[..., triu[0], triu[1]] = 0
        a = bf.asarray(a, space='cuda')
        aa = a.transpose(axes)
        if conj:
            aa = aa.conj()
        c = bf.zeros_like(c_gold, space='cuda')
        self.linalg.matmul(1, aa, None, 0, c)
        c = c.copy('system')
        np.testing.assert_allclose(c, c_gold, RTOL, ATOL)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号