test_hetr_integration.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_distributed_dot(hetr_device, config):
    """Check a dot product computed through the 'hetr' transformer.

    A placeholder ``x`` is multiplied by an all-ones variable ``w`` that is
    created under ``device_id`` / ``parallel`` metadata taken from ``config``.
    The result returned by the transformer's computation must equal
    ``np.dot`` applied to the same inputs.

    Args:
        hetr_device: Device name passed to ``ng.metadata`` and to the
            transformer factory; ``'gpu'`` is marked as an expected failure.
        config: Mapping providing ``'device_id'``, ``'axes_x'``, ``'axes_w'``
            and ``'parallel_axis'``.
    """
    if hetr_device == 'gpu':
        # Imperative xfail: stops the test here for the gpu device.
        pytest.xfail("Intermittent failure on jenkins for mgpu")

    target_device_id = config['device_id']
    x_axes = config['axes_x']
    w_axes = config['axes_w']
    split_axis = config['parallel_axis']

    # Weights are all ones, shaped by the lengths of the weight axes.
    weight_values = np.ones(w_axes.lengths)

    with ng.metadata(device=hetr_device):
        x = ng.placeholder(axes=x_axes)
        with ng.metadata(device_id=target_device_id, parallel=split_axis):
            w = ng.variable(axes=w_axes, initial_value=weight_values)
            dot_op = ng.dot(x, w)

    # Random integer input in [0, 100), shaped by the lengths of the x axes.
    x_values = np.random.randint(100, size=x_axes.lengths)

    make_transformer = ngt.make_transformer_factory('hetr', device=hetr_device)
    with closing(make_transformer()) as transformer:
        run_dot = transformer.computation(dot_op, x)
        actual = run_dot(x_values)
        expected = np.dot(x_values, weight_values)
        np.testing.assert_array_equal(actual, expected)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号