def ngstep(x0, obj0, objgrad0, obj_and_kl_func, hvpx0_func, max_kl, damping, max_cg_iter, enable_bt):
    '''
    Natural gradient step using hessian-vector products.

    Approximately solves (H(x0) + damping*I) s = -objgrad0 with conjugate
    gradient, rescales s so that .5 * s'(H(x0) + damping*I)s is roughly
    max_kl, and optionally backtracks along the rescaled step.

    Args:
        x0: current point (1-D array)
        obj0: objective value at x0
        objgrad0: grad of objective value at x0 (same shape as x0)
        obj_and_kl_func: function mapping a point x to the objective and kl values
        hvpx0_func: function mapping a vector v to the KL Hessian-vector product H(x0)v
        max_kl: max kl divergence limit; sets the length of the full step, and
            the line search rejects any point whose kl exceeds 2*max_kl
        damping: multiple of I to mix with Hessians for Hessian-vector products
        max_cg_iter: max conjugate gradient iterations for solving for natural gradient step
        enable_bt: if false, skip the line search and take the full rescaled step

    Returns:
        Tuple (xnew, num_bt_steps). Without line search this is
        (x0 + fullstep, 0); with line search it is whatever btlinesearch
        returns. If the step cannot be rescaled (non-positive or non-finite
        curvature along the CG solution), a copy of x0 and 0 are returned.
    '''
    assert x0.ndim == 1 and x0.shape == objgrad0.shape

    def damped_hvp_func(v):
        # (H(x0) + damping * I) v
        return hvpx0_func(v) + damping*v

    # Solve for step direction
    dim = x0.shape[0]
    hvpop = ssl.LinearOperator(shape=(dim, dim), matvec=damped_hvp_func)
    # The CG convergence flag is deliberately ignored: a truncated solve is
    # still used as the step direction.
    step, _ = ssl.cg(hvpop, -objgrad0, maxiter=max_cg_iter)

    # Rescale so the quadratic KL estimate of the full step is ~max_kl.
    # The 1e-8 keeps the division well defined when step is (near) zero.
    scale_sq = .5 * step.dot(damped_hvp_func(step)) / max_kl + 1e-8
    if not np.isfinite(scale_sq) or scale_sq <= 0:
        # Indefinite curvature or a diverged CG solve would make the square
        # root NaN/zero and poison the returned point; stay at x0 instead.
        return x0.copy(), 0
    fullstep = step / np.sqrt(scale_sq)

    if not enable_bt:
        return x0+fullstep, 0

    # Line search on objective with a hard KL wall
    def barrierobj(p):
        obj, kl = obj_and_kl_func(p)
        return np.inf if kl > 2*max_kl else obj

    xnew, num_bt_steps = btlinesearch(
        f=barrierobj,
        x0=x0,
        fx0=obj0,
        g=objgrad0,
        dx=fullstep,
        accept_ratio=.1, shrink_factor=.5, max_steps=10)
    return xnew, num_bt_steps
评论列表
文章目录