import numpy as np
import theano
import theano.tensor as T

floatX = theano.config.floatX


def get_adam_updates(f, params, lr=10., b1=0.9, b2=0.999, e=1e-8, dec=5e-3, norm_grads=False):
    """Generate update pairs that minimize f with the Adam optimizer and a 1/(1 + dec*t) learning rate decay."""
    t = theano.shared(0)  # shared step counter
    # First- and second-moment accumulators, one per parameter, initialized to zero.
    ms = [theano.shared(np.zeros(param.shape.eval(), dtype=floatX), borrow=True) for param in params]
    vs = [theano.shared(np.zeros(param.shape.eval(), dtype=floatX), borrow=True) for param in params]
    gs = T.grad(f, params)
    if norm_grads:
        # Optionally rescale each gradient by its L1 norm.
        gs = [g / (T.sum(T.abs_(g)) + 1e-8) for g in gs]
    t_u = (t, t + 1)
    m_us = [(m, b1 * m + (1. - b1) * g) for m, g in zip(ms, gs)]
    v_us = [(v, b2 * v + (1. - b2) * T.sqr(g)) for v, g in zip(vs, gs)]
    t_u_f = T.cast(t_u[1], floatX)
    # Decayed step size with Adam's bias correction folded in.
    lr_hat = (lr / (1. + t_u_f * dec)) * T.sqrt(1. - T.pow(b2, t_u_f)) / (1. - T.pow(b1, t_u_f))
    param_us = [(param, param - lr_hat * m_u[1] / (T.sqrt(v_u[1]) + e)) for m_u, v_u, param in zip(m_us, v_us, params)]
    return m_us + v_us + param_us + [t_u]
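
For reference, here is a minimal usage sketch; the loss, the parameters W and b, and the data shapes below are hypothetical and not part of the original code. It shows that the returned (shared variable, expression) pairs are passed directly to theano.function as its updates argument, so each call performs one Adam step.

X = T.matrix('X')
y = T.vector('y')
W = theano.shared(np.zeros((10, 1), dtype=floatX), name='W')  # hypothetical parameters
b = theano.shared(np.zeros(1, dtype=floatX), name='b')
loss = T.mean(T.sqr(T.dot(X, W).flatten() + b[0] - y))        # simple squared-error loss
updates = get_adam_updates(loss, [W, b], lr=1e-3)
train_step = theano.function([X, y], loss, updates=updates)   # one call = one Adam step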