import numpy as np
import theano.tensor as T
from collections import OrderedDict

# `shared_variable` and `to_float_X` are project helpers (presumably wrapping
# `theano.shared` and a cast to `theano.config.floatX`).


def adam(cost, params, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6, **kwargs):
    """Adam update rule.

    Scales learning rates by adaptive estimates of the first and second
    moments of the gradients.

    References
    ----------
    .. [1] Kingma, D. P., & Ba, J. (2015). Adam: A Method for Stochastic
       Optimization. https://arxiv.org/pdf/1412.6980v8.pdf
    """
    gparams = T.grad(cost, params)
    updates = OrderedDict()

    # Shared time-step counter; t_t is the current step (starts at 1).
    t = shared_variable(to_float_X(0.))
    t_t = 1. + t
    # Bias-corrected step size, folded into the learning rate.
    l_r_t = learning_rate * T.sqrt(1. - beta2 ** t_t) / (1. - beta1 ** t_t)

    for param, gparam in zip(params, gparams):
        # Moment estimates, zero-initialised with the parameter's shape and dtype.
        m = shared_variable(np.zeros_like(param.get_value(borrow=True)), broadcastable=param.broadcastable)
        v = shared_variable(np.zeros_like(param.get_value(borrow=True)), broadcastable=param.broadcastable)
        # Exponential moving averages of the gradient and its square.
        m_t = beta1 * m + (1. - beta1) * gparam
        v_t = beta2 * v + (1. - beta2) * T.sqr(gparam)
        updates[m] = m_t
        updates[v] = v_t
        updates[param] = param - l_r_t * m_t / (T.sqrt(v_t) + epsilon)

    updates[t] = t_t
    return updates
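
A minimal usage sketch, assuming Theano is installed and that the repo's `shared_variable` / `to_float_X` helpers behave like the stand-ins below; the toy names `w`, `x`, `cost`, and `train` are illustrative only, not part of the original code.

import numpy as np
import theano
import theano.tensor as T

# Hypothetical stand-ins for the repo's helpers, only to keep this sketch self-contained.
def to_float_X(arr):
    return np.asarray(arr, dtype=theano.config.floatX)

def shared_variable(value, **kwargs):
    return theano.shared(value, **kwargs)

# Toy least-squares problem: drive w towards 1.
w = theano.shared(np.asarray(5., dtype=theano.config.floatX), name='w')
x = T.scalar('x')
cost = T.sqr(w * x - 1.)

# Build the Adam updates and compile one training step.
updates = adam(cost, [w], learning_rate=0.1)
train = theano.function([x], cost, updates=updates)

for _ in range(200):
    train(1.)
print(w.get_value())  # approximately 1.0 after training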