def test_reshape():
    """Check that ``Executor.reshape`` yields a working executor sharing memory.

    A FullyConnected layer with 4 hidden units is bound for a (5, 4) input and
    then reshaped to a (3, 4) input. With every input element set to 1, every
    weight set to 1 and the bias set to 0, each output element equals
    4 * (1 * 1) + 0 = 4, for both the reshaped and the original executor.
    """
    x = mx.sym.Variable('x')
    y = mx.sym.FullyConnected(x, num_hidden=4)
    exe = y.simple_bind(mx.cpu(), x=(5, 4))
    # arg_arrays is ordered like y.list_arguments(): [x, weight, bias].
    exe.arg_arrays[0][:] = 1
    exe.arg_arrays[1][:] = mx.nd.ones((4, 4))
    exe.arg_arrays[2][:] = 0

    new_exe = exe.reshape(x=(3, 4))
    new_exe.forward(is_train=False)
    # test sub exec forward: shape follows the new input, values are all 4
    assert new_exe.outputs[0].shape == (3, 4)
    assert np.all(new_exe.outputs[0].asnumpy() == 4)
    # test shared memory: the reshaped executor shares buffers with the base
    # executor, so its forward pass has already filled exe's first 3 rows
    # even though exe.forward has not been called yet.
    assert np.all(exe.outputs[0].asnumpy()[:3] == 4)
    # test base exec forward: the original executor keeps its own shape
    exe.forward(is_train=False)
    assert exe.outputs[0].shape == (5, 4)
    assert np.all(exe.outputs[0].asnumpy() == 4)
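# Web-page residue from the source this test was copied from (not part of the
# test): "评论列表" ("Comment list") / "文章目录" ("Table of contents").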