def test_multiple_gather_ops(hetr_device):
    """Check two chained ops built under a parallel-W metadata scope on hetr.

    ``x_plus_one`` and ``x_mul_two`` are both created inside an
    ``ng.metadata(device_id=('0', '1'), parallel=W)`` scope. Both are then
    fetched from a single hetr computation and compared against the
    equivalent NumPy arithmetic on the same random integer input.

    Args:
        hetr_device: Device name forwarded to the hetr transformer factory
            (the test special-cases ``'gpu'``). Presumably supplied by a
            pytest fixture/parametrization; confirm against conftest.
    """
    if hetr_device == 'gpu':
        if 'gpu' not in ngt.transformer_choices():
            pytest.skip("GPUTransformer not available")
        # Known GPU failure: mark the test as an expected failure rather than
        # letting it error out.
        pytest.xfail("Failure due to gather recv tensor being returned in wrong shape, "
                     "possible mismatch between op layout and op.tensor layout")

    H = ng.make_axis(length=2, name='height')
    W = ng.make_axis(length=4, name='width')
    x = ng.placeholder(axes=[H, W])

    # Both ops are placed on devices '0' and '1' with W as the parallel axis.
    # Per the test name and the GPU xfail reason, their outputs are presumably
    # gathered back before being returned (TODO confirm against hetr passes).
    with ng.metadata(device_id=('0', '1'), parallel=W):
        x_plus_one = x + 1
        x_mul_two = x_plus_one * 2

    # Random integer feed for the placeholder, shaped (len(H), len(W)).
    # Named to avoid shadowing the `input` builtin.
    x_value = np.random.randint(100, size=x.axes.lengths)

    # closing() guarantees the transformer is shut down even if an assertion
    # fails inside the block.
    with closing(ngt.make_transformer_factory('hetr', device=hetr_device)()) as hetr:
        compute_both = hetr.computation([x_mul_two, x_plus_one], x)
        # Results come back in the order the ops were listed above.
        result_mul_two, result_plus_one = compute_both(x_value)
        np.testing.assert_array_equal(result_plus_one, x_value + 1)
        np.testing.assert_array_equal(result_mul_two, (x_value + 1) * 2)
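# NOTE(review): two web-page navigation labels ("comment list", "table of
# contents") left over from scraping were here; they are not part of the code.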