test_constant.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_cputensor_fusion():
    """TODO."""
    M = ng.make_axis(length=1)
    N = ng.make_axis(length=3)

    np_a = np.array([[1, 2, 3]], dtype=np.float32)
    np_b = np.array([[3, 2, 1]], dtype=np.float32)
    np_d = np.multiply(np_b, np.add(np_a, 2))

    a = ng.constant(np_a, [M, N])
    b = ng.constant(np_b, [M, N])
    c = ng.constant(2)
    d = ng.multiply(b, ng.add(a, c))

    with executor(d) as ex:
        result = ex()
    print(result)
    assert np.array_equal(result, np_d)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号