def test_allreduce_hint(hetr_device, config):
    """Check a HeTr computation on a variable carrying a 'reduce_func' hint.

    Under a parallel ``ng.metadata`` scope, builds ``var_A`` (all ones) and
    ``var_B`` (all ``config['input']``) over a length-4 axis. It tags
    ``var_B`` with ``metadata['reduce_func'] = 'sum'`` and evaluates
    ``var_A - var_B / len(config['device_id'])`` with a HeTr transformer.
    The result is compared against a float32 array filled with
    ``config['expected_result']``.

    Args:
        hetr_device: device name passed to ``ng.metadata`` and to the HeTr
            transformer factory. For ``'gpu'`` the test is skipped when the
            GPU transformer is not among ``ngt.transformer_choices()``.
        config: dict with keys:

            - ``'input'``: initial value of every element of ``var_B``.
            - ``'device_id'``: device ids for the metadata scope; its length
              is the divisor applied to ``var_B``.
            - ``'expected_result'``: fill value(s) for the expected array.
    """
    if hetr_device == 'gpu' and 'gpu' not in ngt.transformer_choices():
        pytest.skip("GPUTransformer not available")

    # Renamed from `input` so the builtin is not shadowed.
    input_value = config['input']
    device_id = config['device_id']
    axis_A = ng.make_axis(length=4, name='axis_A')
    parallel_axis = ng.make_axis(name='axis_parallel', length=16)

    with ng.metadata(device=hetr_device,
                     device_id=device_id,
                     parallel=parallel_axis):
        var_A = ng.variable(axes=[axis_A], initial_value=UniformInit(1, 1))
        var_B = ng.variable(axes=[axis_A],
                            initial_value=UniformInit(input_value, input_value))
        # NOTE(review): presumably asks HeTr to all-reduce var_B with a sum
        # across the devices in device_id -- confirm against the HeTr passes.
        var_B.metadata['reduce_func'] = 'sum'
        # Dividing by the device count turns the (presumed) summed value
        # into a per-device mean.
        var_B_mean = var_B / len(device_id)
        var_minus = var_A - var_B_mean

    with closing(ngt.make_transformer_factory('hetr', device=hetr_device)()) as hetr:
        out_comp = hetr.computation(var_minus)
        result = out_comp()
        # Explicit 1-tuple shape; `(axis_A.length)` was just a parenthesized int.
        np_result = np.full((axis_A.length,), config['expected_result'], np.float32)
        np.testing.assert_array_equal(result, np_result)
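# NOTE(review): web-page residue ("Comment list", "Table of contents") from the
# page this file was copied from; it is not code and was turned into this comment.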