test_operators.py source code

python

Project: indigo    Author: mbdriscoll
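import numpy as np
import pytest

import indigo.util


def test_One(backend, M, N, K, alpha, beta, forward):
    # pytest supplies backend, M, N, K, alpha, beta and forward (fixtures/parametrize).
    x = indigo.util.rand64c(K, N)
    y = indigo.util.rand64c(M, N)
    B = backend()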
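    # Check that onemm exists before inspecting it; otherwise B.onemm raises
    # AttributeError instead of skipping. Also skip backends that only inherit
    # the abstract stub.
    if not hasattr(B, 'onemm') or getattr(B.onemm, '__isabstractmethod__', False):
        pytest.skip("backend <%s> doesn't implement onemm" % backend.__name__)
    # O is an M x K operator whose entries are all ones.
    O = B.One((M, K), dtype=np.complex64)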

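    # Forward maps u (K x N) to v (M x N); the adjoint swaps the roles.
    if forward:
        u, v = x, y
    else:
        v, u = x, y

    u_d = B.copy_array(u)
    v_d = B.copy_array(v)
    # Every row of O @ u (or O^H @ u) is the column sum of u, so the expected
    # result is beta*v plus alpha times those sums broadcast to v's shape.
    exp = beta * v + np.broadcast_to(alpha * u.sum(axis=0, keepdims=True), v.shape)
    # eval writes its result into v_d: v_d = beta*v_d + alpha*(O or O^H) @ u_d.
    O.eval(v_d, u_d, alpha=alpha, beta=beta, forward=forward)
    act = v_d.to_host()
    np.testing.assert_allclose(act, exp, rtol=1e-5)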
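The expected value in test_One relies on the fact that multiplying by an all-ones matrix replaces every row with the column sums of the input. Below is a minimal NumPy sketch that checks this shortcut against an explicit dense matrix, assuming (as the test's `exp` implies) that forward computes `beta*v + alpha*(O @ u)` and the adjoint computes `beta*v + alpha*(O^H @ u)`. The helpers `rand_complex`, `one_dense`, `one_shortcut` and `shortcut_matches_dense` are hypothetical illustrations, not part of indigo, and no backend is involved.

```python
import numpy as np


def rand_complex(rng, shape):
    """Random complex128 array (hypothetical stand-in for indigo.util.rand64c)."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)


def one_dense(u, v, M, K, alpha, beta, forward):
    """Reference result using an explicit M x K all-ones matrix (assumed semantics)."""
    O = np.ones((M, K), dtype=u.dtype)
    # The adjoint of a real all-ones matrix is just its transpose.
    op = O if forward else O.conj().T
    return beta * v + alpha * (op @ u)


def one_shortcut(u, v, alpha, beta):
    """The expression used for `exp` in test_One: column sums broadcast to v's shape."""
    return beta * v + np.broadcast_to(alpha * u.sum(axis=0, keepdims=True), v.shape)


def shortcut_matches_dense(M, N, K, alpha, beta, forward, seed=0):
    rng = np.random.default_rng(seed)
    x = rand_complex(rng, (K, N))
    y = rand_complex(rng, (M, N))
    # Same role swap as test_One: forward maps K x N -> M x N, the adjoint the reverse.
    u, v = (x, y) if forward else (y, x)
    return np.allclose(one_shortcut(u, v, alpha, beta),
                       one_dense(u, v, M, K, alpha, beta, forward),
                       rtol=1e-10)


if __name__ == "__main__":
    for fwd in (True, False):
        print(fwd, shortcut_matches_dense(4, 3, 5, alpha=0.5, beta=2.0, forward=fwd))
```

The sketch uses complex128 on the host so the comparison tolerance can be much tighter than the `rtol=1e-5` the original test applies to complex64 device results.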