updates.py source code


Project: Synkhronos  Author: astooke
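import numpy as np
import theano
import theano.tensor as T


def adam(loss, params, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
    # flat_unflat_grads() is defined elsewhere in Synkhronos. As used here, it
    # returns a flat shared gradient buffer, the symbolic flat gradient that
    # fills it, and the per-parameter gradients read back out of that buffer.
    grad_shared_flat, flat_grad, unflat_grads = flat_unflat_grads(loss, params)
    grad_updates = [(grad_shared_flat, flat_grad)]

    # Shared timestep counter; t is the step number after this update.
    t_prev = theano.shared(np.array(0, dtype=theano.config.floatX))
    one = T.constant(1)
    t = t_prev + one
    # Step size with both bias corrections folded in.
    a_t = learning_rate * T.sqrt(one - beta2 ** t) / (one - beta1 ** t)

    param_updates = list()
    for p, g in zip(params, unflat_grads):
        value = p.get_value(borrow=True)
        # First- and second-moment estimates, one pair per parameter.
        m_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=p.broadcastable)
        v_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=p.broadcastable)
        m_t = beta1 * m_prev + (one - beta1) * g
        v_t = beta2 * v_prev + (one - beta2) * g ** 2
        step = a_t * m_t / (T.sqrt(v_t) + epsilon)
        param_updates += [(m_prev, m_t), (v_prev, v_t), (p, p - step)]
    # Update the timestep once, after the loop. Appending it inside the loop
    # gives t_prev one duplicate update per parameter, and theano.function
    # raises an error when a shared variable has more than one update.
    param_updates += [(t_prev, t)]
    return grad_updates, param_updates, grad_shared_flat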
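To make the update rule easier to check outside a Theano graph, here is a minimal NumPy sketch of the same per-parameter arithmetic. It is an illustration, not part of Synkhronos: the function name adam_step is hypothetical, and the moment estimates and timestep are passed in and returned explicitly instead of living in shared variables. As in the Theano code, the bias corrections are folded into the step size a_t; this matches the textbook form learning_rate * m_hat / (sqrt(v_hat) + epsilon), except that epsilon is effectively scaled by 1 / sqrt(1 - beta2**t).

```python
import numpy as np


def adam_step(p, g, m, v, t, learning_rate=0.001, beta1=0.9, beta2=0.999,
              epsilon=1e-8):
    """One Adam update in NumPy (hypothetical helper mirroring adam() above).

    p, g, m, v are arrays of the same shape; t is the number of steps taken so
    far. Returns new (p, m, v, t) without modifying the inputs in place.
    """
    t = t + 1
    # Bias corrections for both moments folded into one step size.
    a_t = learning_rate * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    m = beta1 * m + (1 - beta1) * g          # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * g ** 2     # second-moment estimate
    p = p - a_t * m / (np.sqrt(v) + epsilon)
    return p, m, v, t


if __name__ == "__main__":
    # Minimize f(x) = sum((x - 3)**2), whose gradient is 2 * (x - 3).
    x = np.zeros(2)
    m, v, t = np.zeros_like(x), np.zeros_like(x), 0
    for _ in range(100):
        x, m, v, t = adam_step(x, 2 * (x - 3.0), m, v, t, learning_rate=0.01)
    print(t, x)
```

On the very first step the update is almost exactly learning_rate times the negative sign of the gradient, whatever the gradient's magnitude, which makes a quick sanity check for any port of this code.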