updates.py source file

python
Project: yadll  Author: pchavanne
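from collections import OrderedDict

import numpy as np
import theano.tensor as T

# shared_variable and to_float_X are yadll helpers, assumed to be in scope
# from elsewhere in the package.


def adam(cost, params, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6, **kwargs):
    """Adam gradient descent.

    Scales the learning rate of each parameter using adaptive moment
    estimation. Extra keyword arguments are accepted but ignored.

    References
    ----------
    .. [1] https://arxiv.org/pdf/1412.6980v8.pdf
    """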
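    # Symbolic gradients of the cost with respect to each parameter.
    gparams = T.grad(cost, params)
    updates = OrderedDict()
    # Shared step counter; t_t is the 1-based index of the current step.
    t = shared_variable(to_float_X(0.))
    t_t = 1. + t
    # Fold both bias corrections into the step size.
    l_r_t = learning_rate * T.sqrt(1. - beta2 ** t_t) / (1. - beta1 ** t_t)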
    for param, gparam in zip(params, gparams):
        # First and second moment accumulators, shaped like the parameter.
        shape = param.get_value(borrow=True).shape
        m = shared_variable(np.zeros(shape), broadcastable=param.broadcastable)
        v = shared_variable(np.zeros(shape), broadcastable=param.broadcastable)
        # Exponential moving averages of the gradient and of its square.
        m_t = beta1 * m + (1. - beta1) * gparam
        v_t = beta2 * v + (1. - beta2) * T.sqr(gparam)
        updates[m] = m_t
        updates[v] = v_t
        updates[param] = param - l_r_t * m_t / (T.sqrt(v_t) + epsilon)
    # Advance the step counter once per application of the updates.
    updates[t] = t_t
    return updates
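
The symbolic update rule above can be hard to follow without running it. The following is a minimal sketch of the same arithmetic in plain Python, with no Theano dependency, assuming parameters and gradients are flat lists of floats. The names `adam_step` and `minimize_quadratic` are hypothetical and are not part of yadll; the toy objective is chosen only for illustration.

```python
import math


def adam_step(params, grads, m, v, t, learning_rate=0.001,
              beta1=0.9, beta2=0.999, epsilon=1e-6):
    """One Adam step on plain Python lists (hypothetical helper).

    Returns the new (params, m, v, t) instead of mutating shared state.
    """
    t_t = t + 1
    # Bias-corrected step size, same expression as l_r_t in the listing.
    l_r_t = learning_rate * math.sqrt(1.0 - beta2 ** t_t) / (1.0 - beta1 ** t_t)
    new_params, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(params, grads, m, v):
        # Moving averages of the gradient and of its square.
        m_t = beta1 * m_i + (1.0 - beta1) * g
        v_t = beta2 * v_i + (1.0 - beta2) * g * g
        new_m.append(m_t)
        new_v.append(v_t)
        new_params.append(p - l_r_t * m_t / (math.sqrt(v_t) + epsilon))
    return new_params, new_m, new_v, t_t


def minimize_quadratic(start, steps=2000, learning_rate=0.01):
    """Toy example: minimize f(x) = sum(x_i ** 2), whose gradient is 2 * x_i."""
    params = list(start)
    m = [0.0] * len(params)
    v = [0.0] * len(params)
    t = 0
    for _ in range(steps):
        grads = [2.0 * p for p in params]
        params, m, v, t = adam_step(params, grads, m, v, t, learning_rate)
    return params
```

On the first step the moment estimates start at zero, so the bias-corrected step size makes the parameter move by roughly `learning_rate` in the direction opposite to the gradient, regardless of the gradient's magnitude.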