test_hetr_integration.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_comm_broadcast_op(hetr_device):
    """Check a hetr dot product whose right operand is broadcast to devices.

    The dot product is built under ``ng.metadata`` with ``device_id=('0', '1')``
    and ``parallel`` set to the batch axis. The left operand ``x`` has the
    batch axis; the right operand ``w`` does not and, per the original test's
    intent, is broadcast to the devices. The computed result is compared
    element-wise against ``np.dot`` on the same random integer inputs.

    Skipped when ``hetr_device`` is ``'gpu'``.
    """
    # Guard clause: this op has no GPU implementation.
    if hetr_device == 'gpu':
        pytest.skip('gpu communication broadcast op is not supported.')

    height_ax = ng.make_axis(length=4, name='height')
    batch_ax = ng.make_axis(length=8, name='batch')
    weight_ax = ng.make_axis(length=2, name='weight')

    lhs = ng.placeholder(axes=[batch_ax, height_ax])
    # rhs lacks the batch (parallel) axis; it is the tensor broadcast to devices.
    rhs = ng.placeholder(axes=[height_ax, weight_ax])

    device_ids = ('0', '1')
    with ng.metadata(device_id=device_ids, parallel=batch_ax):
        product = ng.dot(lhs, rhs)

    # Integers drawn from [0, 100): small enough that an exact comparison is safe.
    lhs_value = np.random.randint(100, size=[batch_ax.length, height_ax.length])
    rhs_value = np.random.randint(100, size=[height_ax.length, weight_ax.length])

    with ExecutorFactory() as factory:
        run = factory.executor(product, lhs, rhs)
        actual = run(lhs_value, rhs_value)
        expected = np.dot(lhs_value, rhs_value)
        np.testing.assert_array_equal(actual, expected)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号