import tensorflow as tf


def adam_updates(params, cost_or_grads, lr=0.001, B1=0.9, B2=0.999):
    ''' Adam optimizer '''
    updates = []
    if not isinstance(cost_or_grads, list):
        grads = tf.gradients(cost_or_grads, params)
    else:
        grads = cost_or_grads
    # global step counter shared by all parameters, used for bias correction
    t = tf.Variable(1., trainable=False, name='adam_t')
    for p, g in zip(params, grads):
        # second-moment (uncentered variance) accumulator; the ':0' suffix is
        # stripped from p.name so the slot gets a valid variable name
        v = tf.Variable(tf.zeros(p.get_shape()), trainable=False,
                        name=p.name.split(':')[0] + '_adam_v')
        if B1 > 0:
            # first-moment (mean) accumulator
            m = tf.Variable(tf.zeros(p.get_shape()), trainable=False,
                            name=p.name.split(':')[0] + '_adam_m')
            m_t = B1 * m + (1. - B1) * g
            m_hat = m_t / (1. - tf.pow(B1, t))   # bias-corrected first moment
            updates.append(m.assign(m_t))
        else:
            m_hat = g   # B1 == 0: no momentum, RMSProp-style update
        v_t = B2 * v + (1. - B2) * tf.square(g)
        v_hat = v_t / (1. - tf.pow(B2, t))       # bias-corrected second moment
        g_t = m_hat / tf.sqrt(v_hat + 1e-8)      # epsilon inside the sqrt for numerical stability
        p_t = p - lr * g_t
        updates.append(v.assign(v_t))
        updates.append(p.assign(p_t))
    updates.append(t.assign_add(1))              # advance the step counter once per update
    return tf.group(*updates)
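
# Minimal usage sketch (an assumption, not part of the original snippet): a toy
# least-squares model in TF1 graph mode; x, y, w, b and loss are hypothetical
# names used only to show how the returned group op is executed.
import numpy as np

x = tf.placeholder(tf.float32, [None, 10])
y = tf.placeholder(tf.float32, [None, 1])
w = tf.Variable(tf.zeros([10, 1]), name='w')
b = tf.Variable(tf.zeros([1]), name='b')
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) + b - y))

train_op = adam_updates([w, b], loss, lr=0.001)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    x_np = np.random.randn(32, 10).astype(np.float32)
    y_np = np.random.randn(32, 1).astype(np.float32)
    for _ in range(100):
        sess.run(train_op, feed_dict={x: x_np, y: y_np})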