def test_repeat_computation(hetr_device, config):
    """Check that two computations of the same graph each evaluate ``x + 1``.

    A placeholder ``x`` and the op ``x + 1`` are built once. Two separate
    computations for that op are then created on a single ``hetr``
    transformer, and both are run on the same random input. Each result must
    equal ``input + 1``.

    ``config`` is read for the keys ``'device_id'``, ``'axes'`` and
    ``'parallel_axis'``. ``hetr_device`` is passed both to the graph
    metadata and to the transformer factory. On ``'gpu'`` the test is
    marked as an expected failure.
    """
    if hetr_device == 'gpu':
        pytest.xfail("enable after gpu exgraph")

    dev_ids = config['device_id']
    in_axes = config['axes']
    par_axis = config['parallel_axis']

    # The placeholder is created under the outer metadata scope only.
    # The addition is created inside a nested scope that also sets
    # device_id and parallel.
    with ng.metadata(device=hetr_device):
        inp = ng.placeholder(axes=in_axes)
        with ng.metadata(device_id=dev_ids, parallel=par_axis):
            inp_plus_one = inp + 1

    sample = np.random.randint(100, size=in_axes.lengths)
    expected = sample + 1

    factory = ngt.make_transformer_factory('hetr', device=hetr_device)
    with closing(factory()) as transformer:
        # Create both computations before running either one. This is the
        # "repeat" scenario under test.
        computations = [transformer.computation(inp_plus_one, inp)
                        for _ in range(2)]
        for computation in computations:
            np.testing.assert_array_equal(computation(sample), expected)
评论列表
文章目录