test_executor.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:mxnet_tk1 作者: starimpact 项目源码 文件源码
def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None):
    """Check function consistency with uniform random numbers.

    Builds the symbol ``sf(lhs, rhs)`` (or ``uf(lhs, rhs)`` when ``sf`` is
    None) and binds it three ways:

    * positional args with gradient buffers (``executor``),
    * positional args without gradient buffers (``exec3``),
    * keyword (dict) args with gradient buffers (``exec4``).

    Each executor runs a forward pass, and its output is compared with
    ``uf`` applied to the NumPy copies of the inputs. The first executor
    then runs a backward pass with an all-ones head gradient. The resulting
    input gradients are compared with ``gf``.

    Parameters
    ----------
    uf : callable
        Binary function applied to both symbols and NumPy arrays; it serves
        as the forward reference.
    gf : callable
        ``gf(out_grad, lhs, rhs) -> (lhs_grad, rhs_grad)`` on NumPy arrays;
        it serves as the gradient reference.
    dim : int
        Number of axes of the randomly drawn shape. That shape is used for
        any operand whose shape is not given explicitly.
    sf : callable, optional
        Symbolic function to test instead of ``uf``.
    lshape, rshape : tuple, optional
        Explicit shapes for the left/right operands. Each defaults to the
        random shape.
    """
    # Each axis length is drawn from [1, high). high ~= 1000**(1/dim) keeps
    # the element count small. Clamp high to at least 2: for dim >= 10,
    # 1000**(1/dim) < 2 truncates to 1, and randint(1, 1) raises ValueError.
    high = max(2, int(1000 ** (1.0 / dim)))
    shape = tuple(np.random.randint(1, high, size=dim))
    lhs = mx.symbol.Variable('lhs')
    rhs = mx.symbol.Variable('rhs')
    if sf is not None:
        ret = sf(lhs, rhs)
    else:
        ret = uf(lhs, rhs)

    assert ret.list_arguments() == ['lhs', 'rhs']
    lshape = shape if lshape is None else lshape
    rshape = shape if rshape is None else rshape

    lhs_arr = mx.nd.array(np.random.uniform(-1, 1, lshape))
    rhs_arr = mx.nd.array(np.random.uniform(-1, 1, rshape))
    # Gradient buffers start uninitialized; backward() fills them below.
    lhs_grad = mx.nd.empty(lshape)
    rhs_grad = mx.nd.empty(rshape)
    executor = ret.bind(mx.Context('cpu'),
                        args=[lhs_arr, rhs_arr],
                        args_grad=[lhs_grad, rhs_grad])

    # Forward-only executor: no gradient buffers bound.
    exec3 = ret.bind(mx.Context('cpu'),
                     args=[lhs_arr, rhs_arr])

    # Dict-style binding; key order deliberately differs from argument order.
    exec4 = ret.bind(mx.Context('cpu'),
                     args={'rhs': rhs_arr, 'lhs': lhs_arr},
                     args_grad={'lhs': lhs_grad, 'rhs': rhs_grad})

    executor.forward()
    exec3.forward()
    exec4.forward()
    out2 = executor.outputs[0].asnumpy()
    out1 = uf(lhs_arr.asnumpy(), rhs_arr.asnumpy())
    out3 = exec3.outputs[0].asnumpy()
    out4 = exec4.outputs[0].asnumpy()
    assert reldiff(out1, out2) < 1e-6
    assert reldiff(out1, out3) < 1e-6
    assert reldiff(out1, out4) < 1e-6

    # Test the gradient with an all-ones head gradient.
    out_grad = mx.nd.array(np.ones(out2.shape))
    lhs_grad2, rhs_grad2 = gf(out_grad.asnumpy(),
                              lhs_arr.asnumpy(),
                              rhs_arr.asnumpy())
    executor.backward([out_grad])

    assert reldiff(lhs_grad.asnumpy(), lhs_grad2) < 1e-6
    assert reldiff(rhs_grad.asnumpy(), rhs_grad2) < 1e-6
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号