eve_optimizer.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:Deep-Learning 作者: FrankLongueira 项目源码 文件源码
def adam_updates(params, cost_or_grads, lr=0.001, B1=0.9, B2=0.999):
    ''' Adam optimizer '''
    updates = []
    if type(cost_or_grads) is not list:
        grads = tf.gradients(cost_or_grads, params)
    else:
        grads = cost_or_grads
    t = tf.Variable(1., 'adam_t')
    for p, g in zip(params, grads):
        v = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_v')
        if B1>0:
            m = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_m')
            m_t = B1*m + (1. - B1)*g
            m_hat = m_t / (1. - tf.pow(B1,t))
            updates.append(m.assign(m_t))
        else:
            m_hat = g
        v_t = B2*v + (1. - B2)*tf.square(g)
        v_hat = v_t / (1. - tf.pow(B2,t))
        g_t = m_hat / tf.sqrt(v_hat + 1e-8)        
        p_t = p - lr * g_t
        updates.append(v.assign(v_t))
        updates.append(p.assign(p_t))
    updates.append(t.assign_add(1))
    return tf.group(*updates)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号