solvers.py source code

python

Project: DeepEnhancer    Author: minxueric
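from collections import OrderedDict

import numpy as np
import theano
import theano.tensor as T


def adam(loss, params, learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-8):
    """Build Adam (Kingma & Ba) update rules for a Theano training function.

    Returns an OrderedDict mapping each shared variable to its new value,
    suitable for theano.function(..., updates=...).
    """
    grads = T.grad(loss, params)
    updates = OrderedDict()
    # Time-step counter shared by all parameters (np.cast is gone in NumPy 2).
    t_prev = theano.shared(np.asarray(0, dtype=theano.config.floatX))
    t = t_prev + 1
    # Step size with both bias corrections folded in.
    a_t = learning_rate * T.sqrt(1 - beta2**t) / (1 - beta1**t)
    for param, grad in zip(params, grads):
        value = param.get_value(borrow=True)
        # First- and second-moment estimates, initialised to zero.
        m_prev = theano.shared(
                np.zeros(value.shape, dtype=value.dtype),
                broadcastable=param.broadcastable)
        v_prev = theano.shared(
                np.zeros(value.shape, dtype=value.dtype),
                broadcastable=param.broadcastable)
        m_t = beta1 * m_prev + (1 - beta1) * grad
        v_t = beta2 * v_prev + (1 - beta2) * grad ** 2
        step = a_t * m_t / (T.sqrt(v_t) + epsilon)

        updates[m_prev] = m_t
        updates[v_prev] = v_t
        updates[param] = param - step
    updates[t_prev] = t
    return updates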
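For readers without a Theano installation, here is a minimal plain-Python sketch of the same update rule, assuming parameters and gradients are flat lists of floats. The class `Adam`, its `step` method and the `minimize_demo` quadratic are hypothetical illustrations, not part of DeepEnhancer. The arithmetic mirrors `adam()` above: a shared step counter, zero-initialised moment estimates, both bias corrections folded into `a_t`, and `epsilon` added to `sqrt(v_t)`.

```python
import math
from typing import List


class Adam:
    """Plain-Python Adam state for a flat list of scalar parameters.

    Hypothetical helper mirroring the Theano graph built by adam():
    t, m and v play the roles of t_prev, m_prev and v_prev.
    """

    def __init__(self, n_params: int, learning_rate: float = 1e-3,
                 beta1: float = 0.9, beta2: float = 0.999,
                 epsilon: float = 1e-8) -> None:
        self.lr = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.t = 0                   # step counter (t_prev)
        self.m = [0.0] * n_params    # first-moment estimates (m_prev)
        self.v = [0.0] * n_params    # second-moment estimates (v_prev)

    def step(self, params: List[float], grads: List[float]) -> List[float]:
        """Apply one Adam update and return the new parameter values."""
        self.t += 1
        # Both bias corrections folded into a single scalar step size.
        a_t = (self.lr * math.sqrt(1 - self.beta2 ** self.t)
               / (1 - self.beta1 ** self.t))
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            step = a_t * self.m[i] / (math.sqrt(self.v[i]) + self.epsilon)
            new_params.append(p - step)
        return new_params


def minimize_demo(steps: int = 2000, lr: float = 0.1) -> List[float]:
    """Hypothetical demo: minimise (x - 3)**2 + (y + 1)**2 from (0, 0)."""
    opt = Adam(n_params=2, learning_rate=lr)
    params = [0.0, 0.0]
    for _ in range(steps):
        x, y = params
        grads = [2 * (x - 3), 2 * (y + 1)]   # analytic gradient
        params = opt.step(params, grads)
    return params
```

Calling `minimize_demo()` should return values close to `[3.0, -1.0]`. Folding the bias corrections into the scalar `a_t` is the reordering suggested in the Adam paper: it avoids building bias-corrected copies of `m` and `v` for every parameter, with the side effect that `epsilon` acts on the uncorrected `sqrt(v_t)`. Note also that the very first step moves each parameter by almost exactly `learning_rate`, regardless of the gradient's magnitude.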