custom_ops.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:mxnet-gan 作者: tqchen 项目源码 文件源码
def test_constant():
    xpu = mx.gpu()
    shape = (2, 2, 100)
    x = mx.nd.ones(shape, ctx=xpu)
    y = mx.nd.ones(shape, ctx=xpu)
    gy = mx.nd.zeros(shape, ctx=xpu)
    X = constant(x) + mx.sym.Variable('Y')
    xexec = X.bind(xpu,
                   {'Y': y},
                   {'Y': gy})
    xexec.forward()
    np.testing.assert_allclose(
        xexec.outputs[0].asnumpy(), (x + y).asnumpy())
    xexec.backward([y])
    np.testing.assert_allclose(
        gy.asnumpy(), y.asnumpy())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号