test_operators.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:indigo 作者: mbdriscoll 项目源码 文件源码
def test_batch(backend, M, N, K, density, alpha, beta, batch):
    b = backend()
    A_h = indigo.util.randM(M, N, density)
    A = b.SpMatrix(A_h, batch=batch)

    # forward
    x = b.rand_array((N,K))
    y = b.rand_array((M,K))
    y_exp = beta * y.to_host() + alpha * A_h * x.to_host()
    A.eval(y, x, alpha=alpha, beta=beta)
    npt.assert_allclose(y.to_host(), y_exp, rtol=1e-5)

    # adjoint
    x = b.rand_array((M,K))
    y = b.rand_array((N,K))
    y_exp = beta * y.to_host() + alpha * A_h.H * x.to_host()
    A.H.eval(y, x, alpha=alpha, beta=beta)
    npt.assert_allclose(y.to_host(), y_exp, rtol=1e-5)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号