test_comm_nodes.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_multiple_gather_ops(hetr_device):
    """Check that two distributed ops can be returned from one hetr computation.

    ``x_plus_one`` and ``x_mul_two`` are built inside an ``ng.metadata`` block
    with ``device_id=('0', '1')`` and ``parallel=W``. Both are then requested
    as outputs of a single hetr computation. As the test name suggests, the
    transformer therefore has to gather more than one distributed result.
    Each result is compared with the equivalent NumPy expression.

    Args:
        hetr_device: device name passed to the hetr transformer factory
            (presumably a pytest fixture; the code special-cases ``'gpu'``).
    """
    if hetr_device == 'gpu':
        if 'gpu' not in ngt.transformer_choices():
            pytest.skip("GPUTransformer not available")
        # pytest.xfail() raises immediately, so nothing below runs on GPU.
        pytest.xfail("Failure due to gather recv tensor being returned in wrong shape, "
                     "possible mismatch between op layout and op.tensor layout")

    H = ng.make_axis(length=2, name='height')
    W = ng.make_axis(length=4, name='width')
    x = ng.placeholder(axes=[H, W])
    # Ops created here carry device_id=('0', '1') and parallel=W metadata;
    # presumably hetr splits them along W across the two devices — confirm.
    with ng.metadata(device_id=('0', '1'), parallel=W):
        x_plus_one = x + 1
        x_mul_two = x_plus_one * 2

    # Named x_value rather than ``input`` to avoid shadowing the builtin.
    x_value = np.random.randint(100, size=x.axes.lengths)
    with closing(ngt.make_transformer_factory('hetr', device=hetr_device)()) as hetr:
        # One computation returning both outputs, in the order listed here.
        computation = hetr.computation([x_mul_two, x_plus_one], x)
        result_mul_two, result_plus_one = computation(x_value)

        np.testing.assert_array_equal(result_plus_one, x_value + 1)
        np.testing.assert_array_equal(result_mul_two, (x_value + 1) * 2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号