def test_distributed_dot(hetr_device, config):
    """Check that a dot product computed through the 'hetr' transformer matches NumPy.

    The graph multiplies a placeholder ``x`` by an all-ones variable ``w`` and
    the result is compared element-for-element with ``np.dot``.

    Args:
        hetr_device: Device name passed to ``ng.metadata`` and to the
            transformer factory. The value ``'gpu'`` is marked as an
            expected failure before any graph is built.
        config: Mapping that provides ``'device_id'``, ``'axes_x'``,
            ``'axes_w'`` and ``'parallel_axis'``. Both axes objects must
            expose a ``lengths`` attribute usable as a NumPy shape.
    """
    if hetr_device == 'gpu':
        pytest.xfail("Intermittent failure on jenkins for mgpu")

    device_id = config['device_id']
    axes_x = config['axes_x']
    axes_w = config['axes_w']
    parallel_axis = config['parallel_axis']

    # All-ones weights keep the expected result trivially reproducible.
    np_weight = np.ones(axes_w.lengths)

    # NOTE(review): the nesting below was reconstructed from a source whose
    # indentation had been stripped; `dot` is assumed to be created inside
    # the inner (device_id / parallel) metadata scope -- confirm.
    with ng.metadata(device=hetr_device):
        x = ng.placeholder(axes=axes_x)
        with ng.metadata(device_id=device_id, parallel=parallel_axis):
            w = ng.variable(axes=axes_w, initial_value=np_weight)
            dot = ng.dot(x, w)

    # Random integer input in [0, 100) with the shape given by axes_x.
    np_x = np.random.randint(100, size=axes_x.lengths)

    # closing() guarantees transformer.close() runs even if the assertion fails.
    with closing(ngt.make_transformer_factory('hetr',
                                              device=hetr_device)()) as transformer:
        computation = transformer.computation(dot, x)
        res = computation(np_x)
        np.testing.assert_array_equal(res, np.dot(np_x, np_weight))
# Scraped page residue (not part of the test): "评论列表" (comment list), "文章目录" (table of contents)