test_hetr_integration.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_distributed_dot_parallel_second_axis():
    """Check a two-device ``ng.dot`` split along the second axis of ``x``.

    ``x`` has axes (height, batch) and the parallel axis is ``batch``, i.e.
    the second axis of ``x``. The test is marked xfail up front because
    parallelising on an axis other than the first is not supported yet;
    ``pytest.xfail`` stops the test immediately, so the code below it
    documents the intended check rather than running today.
    """
    pytest.xfail("'parallel' for not first axis isn't supported yet")

    height_ax = ng.make_axis(length=4, name='height')
    batch_ax = ng.make_axis(length=8, name='batch')
    weight_ax = ng.make_axis(length=2, name='weight')

    x_ph = ng.placeholder(axes=[height_ax, batch_ax])
    w_ph = ng.placeholder(axes=[weight_ax, height_ax])

    # Place the dot product on devices '0' and '1', splitting along the
    # batch axis (the second axis of x).
    with ng.metadata(device_id=('0', '1'), parallel=batch_ax):
        product = ng.dot(w_ph, x_ph)

    # Random integer inputs (values in [0, 100)) matching the placeholder
    # shapes. Integers keep the exact array comparison below meaningful.
    x_val = np.random.randint(100, size=[height_ax.length, batch_ax.length])
    w_val = np.random.randint(100, size=[weight_ax.length, height_ax.length])

    with ExecutorFactory() as factory:
        run = factory.executor(product, x_ph, w_ph)
        outcome = run(x_val, w_val)
        np.testing.assert_array_equal(outcome, np.dot(w_val, x_val))
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号