def test_multi_computations(hetr_device):
    """Run two computations that share placeholders on one hetr transformer.

    Builds ``f = x ** 2`` and ``out = y - mean(f)`` inside an
    ``ng.metadata`` scope with ``device_id=('0', '1')`` and
    ``parallel=ax_A``. It then creates two computations on the same
    transformer, one returning ``out`` (inputs ``x``, ``y``) and one
    returning ``f`` (input ``x``), and checks both results against NumPy.

    Args:
        hetr_device: device name passed to the ``'hetr'`` transformer
            factory; ``'gpu'`` is currently marked xfail.
    """
    if hetr_device == 'gpu':
        pytest.xfail("enable after gpu exgraph")

    axes_x = ng.make_axes([ax_A, ax_B])
    x = ng.placeholder(axes=axes_x)
    y = ng.placeholder(())
    # NOTE(review): indentation was reconstructed; this assumes both ``f``
    # and ``out`` are built inside the metadata scope -- confirm upstream.
    with ng.metadata(device_id=('0', '1'), parallel=ax_A):
        f = x ** 2
        out = y - ng.mean(f, out_axes=())

    # Integer inputs in [0, 9].
    np_x = np.random.randint(10, size=axes_x.lengths)
    np_y = np.random.randint(10)
    with closing(ngt.make_transformer_factory('hetr', device=hetr_device)()) as t:
        # Two computations on the same transformer, sharing placeholder ``x``.
        comp = t.computation(out, x, y)
        another_comp = t.computation(f, x)

        res_comp = comp(np_x, np_y)
        res_another_comp = another_comp(np_x)

        # ``out`` contains a mean (a floating-point reduction, possibly split
        # across devices), so compare against the float64 NumPy reference
        # with a tolerance rather than exact equality. Tolerances assume the
        # transformer computes in float32 -- TODO confirm.
        ref_comp = np_y - np.mean(np_x ** 2)
        np.testing.assert_allclose(res_comp, ref_comp, rtol=1e-5, atol=1e-4)
        # Squares of integers in [0, 9] are exactly representable, so an
        # exact comparison is safe for the elementwise result.
        np.testing.assert_array_equal(res_another_comp, np_x ** 2)
评论列表
文章目录