optim.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:rltools 作者: sisl 项目源码 文件源码
def ngstep(x0, obj0, objgrad0, obj_and_kl_func, hvpx0_func, max_kl, damping, max_cg_iter,
           enable_bt):
    '''
    Take one natural gradient step, using Hessian-vector products only.

    The step direction is found by running conjugate gradient on the damped
    system (H(x0) + damping * I) s = -objgrad0. The direction is then rescaled
    so that the quadratic estimate 0.5 * s^T (H + damping * I) s is roughly
    max_kl. Optionally, a backtracking line search is run along the rescaled
    step.

    Args:
        x0: current point (1-D array)
        obj0: objective value at x0
        objgrad0: gradient of the objective at x0 (same shape as x0)
        obj_and_kl_func: function mapping a point x to the pair (objective, kl)
        hvpx0_func: function mapping a vector v to the KL Hessian-vector product H(x0)v
        max_kl: KL divergence budget used to scale the step; during the line
            search any point whose kl exceeds 2 * max_kl is treated as infeasible
        damping: multiple of I added to the Hessian in Hessian-vector products
        max_cg_iter: max conjugate gradient iterations for solving for the step direction
        enable_bt: if falsy, skip the line search and take the rescaled step as is

    Returns:
        A pair (xnew, num_bt_steps). Without backtracking this is
        (x0 + rescaled step, 0); otherwise it is the pair produced by
        btlinesearch (presumably the accepted point and the number of
        backtracking steps taken -- confirm against btlinesearch).
    '''

    assert x0.ndim == 1 and x0.shape == objgrad0.shape

    def apply_damped_hessian(vec):
        # (H(x0) + damping * I) applied to vec
        return hvpx0_func(vec) + damping * vec

    # Solve the damped linear system for the step direction
    dim = x0.shape[0]
    operator = ssl.LinearOperator(shape=(dim, dim), matvec=apply_damped_hessian)
    direction, _ = ssl.cg(operator, -objgrad0, maxiter=max_cg_iter)

    # Rescale so the quadratic KL estimate of the step is about max_kl;
    # the 1e-8 keeps the divisor away from zero.
    curvature = direction.dot(apply_damped_hessian(direction))
    scaled_step = direction / np.sqrt(.5 * curvature / max_kl + 1e-8)

    if not enable_bt:
        # No line search requested: take the rescaled step directly
        return x0 + scaled_step, 0

    def kl_walled_objective(point):
        # Objective with a hard wall: infinite wherever kl exceeds 2 * max_kl
        value, divergence = obj_and_kl_func(point)
        if divergence > 2 * max_kl:
            return np.inf
        return value

    xnew, num_bt_steps = btlinesearch(f=kl_walled_objective, x0=x0, fx0=obj0, g=objgrad0,
                                      dx=scaled_step, accept_ratio=.1, shrink_factor=.5,
                                      max_steps=10)
    return xnew, num_bt_steps
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号