gradient_check.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:PyFunt 作者: dnlcrl 项目源码 文件源码
def eval_numerical_gradient_array(f, x, df, h=1e-5):
    '''
    Evaluate a numeric gradient for a function that accepts a numpy
    array and returns a numpy array.
    '''
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        ix = it.multi_index

        oldval = x[ix]
        x[ix] = oldval + h
        pos = f(x).copy()
        x[ix] = oldval - h
        neg = f(x).copy()
        x[ix] = oldval

        grad[ix] = np.sum((pos - neg) * df) / (2 * h)
        it.iternext()
    return grad
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号