def test_rsqrt_cos_sin():
    """Check ``rsqrt(x) + cos(x) + sin(x)`` forward and backward against NumPy.

    Builds the symbolic expression on a constant (3, 4) input filled with 5,
    binds it on CPU, and compares:

    * the forward output with ``1/sqrt(x) + cos(x) + sin(x)`` computed in NumPy;
    * the input gradient produced by ``backward`` (head gradient of 2) with the
      analytic derivative ``-1/(2*x*sqrt(x)) - sin(x) + cos(x)`` scaled by the
      head gradient.

    Both comparisons use ``reldiff`` with a tolerance of 1e-6.
    """
    data = mx.symbol.Variable('data')
    shape = (3, 4)
    data_tmp = np.ones(shape)
    data_tmp[:] = 5
    arr_data = mx.nd.array(data_tmp)
    arr_grad = mx.nd.empty(shape)
    # Pre-fill the gradient buffer with 3; the final assertion only passes if
    # backward() replaces this content with the computed gradient.
    arr_grad[:] = 3

    test = mx.sym.rsqrt(data) + mx.sym.cos(data) + mx.sym.sin(data)
    exe_test = test.bind(mx.cpu(), args=[arr_data], args_grad=[arr_grad])

    # Forward pass: compare against the NumPy reference.
    exe_test.forward()
    out = exe_test.outputs[0].asnumpy()
    npout = 1 / np.sqrt(data_tmp) + np.cos(data_tmp) + np.sin(data_tmp)
    assert reldiff(out, npout) < 1e-6

    # Backward pass: feed a constant head gradient of 2.
    out_grad = mx.nd.empty(shape)
    out_grad[:] = 2
    npout_grad = out_grad.asnumpy()
    # Chain rule, one term per operator:
    #   d/dx rsqrt(x) = -1 / (2 * x * sqrt(x))
    #   d/dx cos(x)   = -sin(x)
    #   d/dx sin(x)   =  cos(x)
    npout_grad = (
        npout_grad * -(1.0 / (2.0 * data_tmp * np.sqrt(data_tmp)))
        + npout_grad * -1 * np.sin(data_tmp)
        + npout_grad * np.cos(data_tmp)
    )
    exe_test.backward(out_grad)
    assert reldiff(arr_grad.asnumpy(), npout_grad) < 1e-6
# Comment list
# Article table of contents