def test_constant():
    """Check that ``constant(x) + Y`` behaves like a plain elementwise add.

    Forward: the executor output must equal ``x + y``.
    Backward: with a head gradient of ``y`` (all ones), the gradient written
    into the buffer bound for ``'Y'`` must equal ``y``, since
    d(c + Y)/dY == 1. Only ``'Y'`` is bound with a gradient buffer.
    """
    xpu = mx.gpu()
    shape = (2, 2, 100)
    x = mx.nd.ones(shape, ctx=xpu)
    y = mx.nd.ones(shape, ctx=xpu)
    gy = mx.nd.zeros(shape, ctx=xpu)
    # NOTE(review): `constant` is defined elsewhere; presumably it embeds the
    # NDArray `x` into the graph so it needs no bind argument — confirm.
    X = constant(x) + mx.sym.Variable('Y')
    # args: input value for 'Y'; args_grad: buffer receiving dOut/dY.
    xexec = X.bind(xpu, {'Y': y}, {'Y': gy})
    # is_train=True: MXNet expects a training-mode forward when a backward
    # call follows; the default (False) is meant for inference only.
    xexec.forward(is_train=True)
    np.testing.assert_allclose(
        xexec.outputs[0].asnumpy(), (x + y).asnumpy())
    # Head gradient is y (all ones), so the gradient for 'Y' should equal y.
    xexec.backward([y])
    np.testing.assert_allclose(
        gy.asnumpy(), y.asnumpy())
评论列表
文章目录