test_hetr_integration.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_multi_computations(hetr_device):
    """Check that one hetr transformer can host several computations.

    Builds ``f = x ** 2`` and ``out = y - mean(f)`` under
    ``ng.metadata(device_id=('0', '1'), parallel=ax_A)``. It then creates two
    computations from that graph on the same transformer: one evaluating
    ``out`` from ``(x, y)`` and one evaluating ``f`` from ``x``. Both results
    are compared against NumPy references.

    :param hetr_device: device name passed to the hetr transformer factory
        (the ``'gpu'`` case is marked xfail).
    """
    if hetr_device == 'gpu':
        pytest.xfail("enable after gpu exgraph")
    axes_x = ng.make_axes([ax_A, ax_B])
    x = ng.placeholder(axes=axes_x)
    y = ng.placeholder(())
    # Place both ops on devices '0' and '1'; parallel=ax_A presumably splits
    # the work along ax_A across those devices.
    with ng.metadata(device_id=('0', '1'), parallel=ax_A):
        f = x ** 2
        out = y - ng.mean(f, out_axes=())

    np_x = np.random.randint(10, size=axes_x.lengths)
    np_y = np.random.randint(10)
    with closing(ngt.make_transformer_factory('hetr', device=hetr_device)()) as t:
        # Two computations sharing the placeholder ``x`` on a single transformer.
        comp = t.computation(out, x, y)
        another_comp = t.computation(f, x)

        res_comp = comp(np_x, np_y)
        res_another_comp = another_comp(np_x)
        ref_comp = np_y - np.mean(np_x ** 2)
        # The mean is a floating-point reduction. It may be evaluated in lower
        # precision than NumPy's float64 reference and combined across
        # devices, so compare with a tolerance rather than bit-exactly.
        np.testing.assert_allclose(res_comp, ref_comp, rtol=1e-5, atol=1e-5)
        # Squares of integers in [0, 9] are exactly representable in any
        # float format, so exact equality is safe for the elementwise result.
        np.testing.assert_array_equal(res_another_comp, np_x ** 2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号