def test_One(backend, M, N, K, alpha, beta, forward):
    """Check a backend's ``One`` operator against a NumPy reference.

    The reference result is ``beta * v + alpha * colsum(u)``, where
    ``colsum(u)`` is the sum of ``u`` over axis 0 broadcast to the shape
    of ``v``. This is consistent with ``One`` acting as an all-ones
    matrix of shape ``(M, K)`` -- NOTE(review): inferred from the
    reference expression, confirm against the operator's definition.

    Args:
        backend: Backend class; called with no arguments to build the
            instance under test.
        M: First dimension of the operator and of ``y``.
        N: Second dimension of both operands.
        K: Second dimension of the operator and first dimension of ``x``.
        alpha: Scale applied to the summed input.
        beta: Scale applied to the existing output contents.
        forward: If true, ``x`` is the input and ``y`` the output;
            otherwise the roles are swapped.

    The test is skipped when the backend has no concrete ``onemm``.
    """
    # presumably complex64 arrays of the requested shape -- TODO confirm
    # against indigo.util.rand64c.
    x = indigo.util.rand64c(K, N)
    y = indigo.util.rand64c(M, N)
    B = backend()

    # Test for the attribute's existence *before* touching it: reading
    # B.onemm on a backend that lacks it raises AttributeError, which
    # would fail the test instead of skipping it.
    if not hasattr(B, 'onemm'):
        pytest.skip("backend doesn't implement onemm")
    if getattr(B.onemm, '__isabstractmethod__', False):
        pytest.skip("backend <%s> doesn't implement onemm" % backend.__name__)

    O = B.One((M, K), dtype=np.complex64)

    # u is the operand that gets summed, v the one that is updated.
    if forward:
        u, v = x, y
    else:
        v, u = x, y

    u_d = B.copy_array(u)
    v_d = B.copy_array(v)

    # Host-side reference, computed from the untouched host copies.
    col_sums = alpha * u.sum(axis=0, keepdims=True)
    exp = beta * v + np.broadcast_to(col_sums, v.shape)

    # eval writes its result into v_d.
    O.eval(v_d, u_d, alpha=alpha, beta=beta, forward=forward)
    act = v_d.to_host()

    np.testing.assert_allclose(act, exp, rtol=1e-5)
评论列表
文章目录