test_constant.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def test_cputensor_mlp():
    """TODO."""
    D = ng.make_axis(length=3)
    H = ng.make_axis(length=2)
    N = ng.make_axis(length=1)

    np_x = np.array([[1, 2, 3]], dtype=np.float32)
    np_w = np.array([[1, 1], [1, 1], [1, 1]], dtype=np.float32)
    np_b = np.array([1, 2], dtype=np.float32)
    np_c = np.dot(np_x, np_w) + np_b

    x = ng.constant(np_x, [N, D])
    w = ng.constant(np_w, [D, H])
    b = ng.constant(np_b, [H])
    wx = ng.dot(x, w)
    c = wx + b
    with executor(c) as ex:
        result = ex()
    print(result)
    print(np_c)
    assert np.array_equal(result, np_c)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号