def test_distributed_dot_parallel_second_axis():
    """Check ``ng.dot`` built under ``parallel`` metadata on a non-first axis.

    ``x`` has axes (height, batch), so ``batch`` is its second axis. The dot
    product ``w . x`` is built inside ``ng.metadata`` with
    ``device_id=('0', '1')`` and ``parallel=batch``. The executed result is
    then compared element-wise against ``np.dot`` on the same inputs.

    The test is marked as an expected failure up front: ``pytest.xfail``
    ends the test immediately, so nothing after that call runs until the
    feature is supported.
    """
    pytest.xfail("'parallel' for not first axis isn't supported yet")

    height_ax = ng.make_axis(length=4, name='height')
    batch_ax = ng.make_axis(length=8, name='batch')
    weight_ax = ng.make_axis(length=2, name='weight')

    x = ng.placeholder(axes=[height_ax, batch_ax])
    w = ng.placeholder(axes=[weight_ax, height_ax])

    # Only the dot op is created under the device/parallel metadata.
    with ng.metadata(device_id=('0', '1'), parallel=batch_ax):
        dot = ng.dot(w, x)

    # Draw x first, then w, matching the order in which the inputs are
    # generated; values are small integers so float results compare exactly.
    x_value = np.random.randint(100, size=[height_ax.length, batch_ax.length])
    w_value = np.random.randint(100, size=[weight_ax.length, height_ax.length])
    expected = np.dot(w_value, x_value)

    with ExecutorFactory() as ex:
        # Argument order here must match the placeholder order given to
        # ex.executor: (x, w).
        computation = ex.executor(dot, x, w)
        actual = computation(x_value, w_value)
        np.testing.assert_array_equal(actual, expected)
# (web-page residue) Comment list
# (web-page residue) Table of contents