test_executor.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:mxnet_tk1 作者: starimpact 项目源码 文件源码
def test_reshape():
    x = mx.sym.Variable('x')
    y = mx.sym.FullyConnected(x, num_hidden=4)

    exe = y.simple_bind(mx.cpu(), x=(5,4))
    exe.arg_arrays[0][:] = 1
    exe.arg_arrays[1][:] = mx.nd.ones((4,4))
    exe.arg_arrays[2][:] = 0

    new_exe = exe.reshape(x=(3,4))
    new_exe.forward(is_train=False)
    # test sub exec forward
    assert np.all(new_exe.outputs[0].asnumpy() == 4)
    # test shared memory
    assert np.all(exe.outputs[0].asnumpy()[:3] == 4)
    # test base exec forward
    exe.forward(is_train=False)
    assert np.all(exe.outputs[0].asnumpy() == 4)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号