numerical_gradient.py source code

python

Project: Aurora  Author: upul
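import numpy as np

# Executor is Aurora's computation-graph executor. This import path is an
# assumption and may differ between versions of the project.
from aurora.autodiff import Executor


def eval_numerical_grad(f, feed_dict, wrt, h=1e-5):
    """Estimate the gradient of sum(f) with respect to `wrt` by central differences.

    f:         graph node to evaluate
    feed_dict: maps input nodes to NumPy arrays; feed_dict[wrt] is perturbed
               in place and restored before the function returns
    wrt:       input node to differentiate with respect to
    h:         step size; feed_dict[wrt] should be a float64 array, since
               float32 loses most of the precision at h=1e-5 and integer
               arrays cannot hold old_val +/- h at all
    """
    wrt_val = feed_dict[wrt]  # same object as feed_dict[wrt], not a copy
    grad = np.zeros_like(wrt_val)
    # The graph is identical for every evaluation, so one executor suffices.
    executor = Executor([f])

    it = np.nditer(wrt_val, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        ix = it.multi_index
        old_val = wrt_val[ix]

        # Evaluate f(x + h * e_ix).
        wrt_val[ix] = old_val + h
        result_plus, = executor.run(feed_shapes=feed_dict)

        # Evaluate f(x - h * e_ix).
        wrt_val[ix] = old_val - h
        result_minus, = executor.run(feed_shapes=feed_dict)

        # Central difference. np.sum reduces a non-scalar output, so the
        # result is the gradient of the sum of f's entries.
        grad[ix] = np.sum((result_plus - result_minus) / (2.0 * h))

        # Restore the original entry before moving to the next one.
        wrt_val[ix] = old_val
        it.iternext()
    return grad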
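The function above depends on Aurora's Executor, so it cannot run on its own. As a minimal sketch, the same central-difference scheme can be written against a plain Python callable: for each entry x_i, the partial derivative is approximated as (f(x + h*e_i) - f(x - h*e_i)) / (2h), whose truncation error shrinks as O(h^2). The names central_difference_grad and relative_error are hypothetical helpers introduced here for illustration; they are not part of Aurora. Comparing against an analytic gradient with a relative-error metric is a common gradient-check convention, not something the original file does.

```python
import numpy as np


def central_difference_grad(f, x, h=1e-5):
    """Hypothetical helper: estimate the gradient of sum(f(x)) w.r.t. x.

    f is any callable mapping an ndarray to a scalar or ndarray. This is a
    framework-free analogue of eval_numerical_grad, not Aurora's API.
    """
    x = np.array(x, dtype=np.float64)  # float64 copy; the caller's array is untouched
    grad = np.zeros_like(x)
    for ix in np.ndindex(x.shape):
        old_val = x[ix]
        x[ix] = old_val + h
        f_plus = np.sum(f(x))    # sum of f(x + h * e_ix)
        x[ix] = old_val - h
        f_minus = np.sum(f(x))   # sum of f(x - h * e_ix)
        x[ix] = old_val          # restore before the next entry
        grad[ix] = (f_plus - f_minus) / (2.0 * h)
    return grad


def relative_error(a, b, eps=1e-12):
    """Largest elementwise relative error between two gradient arrays."""
    return np.max(np.abs(a - b) / np.maximum(eps, np.abs(a) + np.abs(b)))


# Usage: the analytic gradient of sum(sin(x)) is cos(x).
x = np.array([[0.1, 0.5], [1.0, 2.0]])
num_grad = central_difference_grad(np.sin, x)
print(relative_error(num_grad, np.cos(x)))
```

Because f_plus and f_minus are summed, a vector-valued f yields the gradient of the sum of its outputs, matching the np.sum in eval_numerical_grad. Working on a float64 copy also keeps the caller's array unchanged, whereas the original perturbs feed_dict[wrt] in place and restores each entry afterwards.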