def ngstep(x0, obj0, objgrad0, obj_and_kl_func, hvpx0_func, max_kl, damping, max_cg_iter,
           enable_bt):
    '''
    Compute one natural gradient step from x0 using Hessian-vector products.

    The step direction s solves (H(x0) + damping * I) s = -objgrad0 with
    conjugate gradient. It is then rescaled so that
    .5 * s^T (H(x0) + damping * I) s is approximately max_kl (a 1e-8 term
    is added under the square root to avoid dividing by zero).

    Args:
        x0: current point (1-D array)
        obj0: objective value at x0
        objgrad0: gradient of the objective at x0 (same shape as x0)
        obj_and_kl_func: function mapping a point x to the pair (objective, kl)
        hvpx0_func: function mapping a vector v to the KL Hessian-vector product H(x0)v
        max_kl: KL divergence limit used to scale the full step; during the
            line search, points whose kl exceeds 2 * max_kl are rejected
        damping: multiple of I to mix with the Hessian for Hessian-vector products
        max_cg_iter: max conjugate gradient iterations when solving for the step
        enable_bt: if false, skip the backtracking line search and take the
            full step

    Returns:
        A pair (new point, count). Without the line search this is
        (x0 + full step, 0); otherwise it is the pair produced by
        btlinesearch (presumably the accepted point and the number of
        backtracking steps -- confirm against btlinesearch).
    '''
    assert x0.ndim == 1 and x0.shape == objgrad0.shape

    def apply_damped_hessian(vec):
        # (H(x0) + damping * I) applied to vec.
        return hvpx0_func(vec) + damping * vec

    # Solve for the step direction with conjugate gradient.
    dim = x0.shape[0]
    operator = ssl.LinearOperator(shape=(dim, dim), matvec=apply_damped_hessian)
    direction, _ = ssl.cg(operator, -objgrad0, maxiter=max_cg_iter)

    # Rescale the direction so its quadratic form matches the max_kl budget.
    quad_form = direction.dot(apply_damped_hessian(direction))
    scale = np.sqrt(.5 * quad_form / max_kl + 1e-8)
    fullstep = direction / scale

    if enable_bt:
        # Line search on the objective with a hard KL wall: any candidate
        # whose kl exceeds twice the limit is treated as infinitely bad.
        def barrierobj(p):
            obj, kl = obj_and_kl_func(p)
            if kl > 2 * max_kl:
                return np.inf
            return obj

        accepted, bt_count = btlinesearch(f=barrierobj, x0=x0, fx0=obj0, g=objgrad0,
                                          dx=fullstep, accept_ratio=.1, shrink_factor=.5,
                                          max_steps=10)
        return accepted, bt_count

    return x0 + fullstep, 0
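# Leftover page-navigation labels from the page this code was copied from
# ("Comment list", "Table of contents"); not part of the code.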