def check_elementwise_sum_with_shape(shape, n, ctx=None):
    """Check ElementWiseSum forward and backward on ``n`` inputs of ``shape``.

    Builds an ``ElementWiseSum`` symbol over ``n`` variables, binds it to
    random inputs drawn from U(-10, 10), and verifies that:

    * the forward output matches the NumPy sum of the inputs
      (``reldiff(...) < 1e-6``), and
    * after a backward pass with a random head gradient, every input
      gradient passes ``same(...)`` against that head gradient (the
      derivative of a sum w.r.t. each summand is the identity).

    Args:
        shape: Shape of every input array and of the output.
        n: Number of summands (variables named ``arg0`` .. ``arg{n-1}``).
        ctx: Context to bind the executor to. Defaults to
            ``mx.Context('cpu')``, matching the previous hard-coded value.
    """
    if ctx is None:
        ctx = mx.Context('cpu')

    # forward
    inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
    esum = mx.symbol.ElementWiseSum(*inputs, name='esum')
    arr = [mx.nd.empty(shape) for _ in range(n)]
    arr_grad = [mx.nd.empty(shape) for _ in range(n)]
    for a in arr:
        a[:] = np.random.uniform(-10, 10, shape)
    exec1 = esum.bind(ctx, args=arr, args_grad=arr_grad)
    # is_train=True: the Executor API only permits backward() after a
    # training-mode forward pass.
    exec1.forward(is_train=True)
    out1 = exec1.outputs[0].asnumpy()
    expected = sum(a.asnumpy() for a in arr)
    # reldiff is a helper defined elsewhere in this module; presumably a
    # relative-difference metric — confirm against its definition.
    assert reldiff(expected, out1) < 1e-6

    # backward
    out_grad = mx.nd.empty(shape)
    out_grad[:] = np.random.uniform(-10, 10, shape)
    exec1.backward([out_grad])
    out_grad_np = out_grad.asnumpy()  # hoisted: invariant across the loop
    for a in arr_grad:
        # `same` is defined elsewhere in this module (presumably an
        # exact-equality check — TODO confirm).
        assert same(a.asnumpy(), out_grad_np)
评论列表
文章目录