custom_ops.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:mxnet-gan 作者: tqchen 项目源码 文件源码
def test_log_sum_exp():
    xpu = mx.gpu()
    shape = (2, 2, 100)
    axis = 2
    keepdims = True
    X = mx.sym.Variable('X')
    Y = log_sum_exp(X, axis=axis, keepdims=keepdims)
    x = mx.nd.array(np.random.normal(size=shape))
    x[:] = 1
    xgrad = mx.nd.empty(x.shape)
    exec1 = Y.bind(xpu, args = [x], args_grad = {'X': xgrad})
    exec1.forward()
    y = exec1.outputs[0]
    np.testing.assert_allclose(
        y.asnumpy(),
        np_log_sum_exp(x.asnumpy(), axis=axis, keepdims=keepdims))
    y[:] = 1
    exec1.backward([y])
    np.testing.assert_allclose(
        xgrad.asnumpy(),
        np_softmax(x.asnumpy(), axis=axis) * y.asnumpy())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号