test_hetr_integration.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_repeat_computation(hetr_device, config):
    """Run two computations of the same ``x + 1`` graph on one transformer.

    Both computations are created before either one is executed, and each
    result is compared with ``np_x + 1`` as computed by NumPy.

    Arguments:
        hetr_device: device name passed to ``ng.metadata`` and to the 'hetr'
            transformer factory; 'gpu' is reported through ``pytest.xfail``.
        config: mapping providing 'device_id', 'axes' and 'parallel_axis'.
    """
    if hetr_device == 'gpu':
        pytest.xfail("enable after gpu exgraph")

    # Look the settings up in the same order as before so that a missing key
    # is reported for the same entry.
    device_id, axes, parallel_axis = [
        config[key] for key in ('device_id', 'axes', 'parallel_axis')
    ]

    with ng.metadata(device=hetr_device):
        placeholder = ng.placeholder(axes=axes)
        with ng.metadata(device_id=device_id, parallel=parallel_axis):
            incremented = placeholder + 1

        np_input = np.random.randint(100, size=axes.lengths)
        make_transformer = ngt.make_transformer_factory('hetr', device=hetr_device)
        with closing(make_transformer()) as transformer:
            # Create both computations up front, then execute them one by one.
            computations = [
                transformer.computation(incremented, placeholder)
                for _ in range(2)
            ]
            for computation in computations:
                result = computation(np_input)
                np.testing.assert_array_equal(result, np_input + 1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号