def check_regression(symbol, forward, backward):
    """Check a regression-output operator against NumPy reference functions.

    Builds ``symbol(data, label)`` from two symbol variables, binds it on the
    CPU with randomly drawn NDArray inputs, then compares both the forward
    output and the gradient written for ``data`` against the references.

    Args:
        symbol: Callable taking the ``data`` and ``label`` symbol variables
            and returning the output symbol under test.
        forward: Callable mapping the input data (a NumPy array of shape
            ``(3, 1)``) to the expected forward output.
        backward: Callable taking the expected forward output and the label
            reshaped to that output's shape, and returning the expected
            gradient with respect to ``data``.

    Raises:
        AssertionError: If ``reldiff`` between the operator's result and the
            reference is not below ``1e-6``, for either the forward output
            or the ``data`` gradient.
    """
    data = mx.symbol.Variable('data')
    label = mx.symbol.Variable('label')
    out = symbol(data, label)

    shape = (3, 1)
    arr_data = mx.random.uniform(-1, 1, shape)
    # One label value per row of data; it is reshaped to the output's shape
    # before being handed to the reference backward function.
    arr_label = mx.random.uniform(0, 1, shape[0])
    # Only "data" gets a gradient buffer; the label is treated as a constant.
    arr_grad = mx.nd.empty(shape)
    exec1 = out.bind(mx.cpu(),
                     args=[arr_data, arr_label],
                     args_grad={"data" : arr_grad})

    # is_train=True is required because a backward pass follows; the default
    # (is_train=False) runs the executor in inference mode.
    exec1.forward(is_train=True)
    out1 = exec1.outputs[0].asnumpy()
    npout = forward(arr_data.asnumpy())
    assert reldiff(npout, out1) < 1e-6

    exec1.backward()
    npgrad = backward(npout, arr_label.asnumpy().reshape(npout.shape))
    assert reldiff(npgrad, arr_grad.asnumpy()) < 1e-6
评论列表
文章目录