def test_sign():
    """Check ``mx.sym.sign`` forward output and backward gradient.

    Forward: the executor output must match ``np.sign`` element-wise on an
    input containing negative values, zero, and positive values.

    Backward: ``sign`` is piecewise constant, so the expected gradient with
    respect to ``data`` is all zeros whatever head gradient is supplied.
    """
    data = mx.symbol.Variable('data')
    shape = (3, 4)
    # Values -6 .. 5 (including 0) so the negative, zero and positive
    # branches of sign() are all exercised.
    data_tmp = np.arange(-6.0, 6.0).reshape(shape)
    arr_data = mx.nd.array(data_tmp)
    # Pre-fill the gradient buffer with a non-zero sentinel so the backward
    # check below proves backward() actually wrote zeros into it.
    arr_grad = mx.nd.empty(shape)
    arr_grad[:] = 3

    test = mx.sym.sign(data)
    exe_test = test.bind(mx.cpu(), args=[arr_data], args_grad=[arr_grad])
    # is_train=True because a backward() call follows this forward pass.
    exe_test.forward(is_train=True)
    out = exe_test.outputs[0].asnumpy()
    npout = np.sign(data_tmp)
    assert reldiff(out, npout) < 1e-6

    # A non-zero head gradient must still produce a zero input gradient.
    out_grad = mx.nd.empty(shape)
    out_grad[:] = 2
    npout_grad = np.zeros(shape)
    exe_test.backward(out_grad)
    assert reldiff(arr_grad.asnumpy(), npout_grad) < 1e-6
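# Stray web-page residue from the scraped source (translated):
# "评论列表" = "Comment list"; "文章目录" = "Table of contents".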