import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.compat.v1 under TF 2)


def adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999):
    ''' Adam optimizer '''
    updates = []
    if type(cost_or_grads) is not list:
        grads = tf.gradients(cost_or_grads, params)
    else:
        grads = cost_or_grads
    # Global step counter used for bias correction. Optimizer state is created
    # with keyword arguments: in the original, the name string was passed
    # positionally into tf.Variable's `trainable` slot, which (being truthy)
    # made the state trainable and left it unnamed.
    t = tf.Variable(1., trainable=False, name='adam_t')
    for p, g in zip(params, grads):
        # Second-moment accumulator: running average of squared gradients.
        # p.name ends in ':0', which is stripped since ':' is illegal in names.
        mg = tf.Variable(tf.zeros(p.get_shape()), trainable=False,
                         name=p.name.split(':')[0] + '_adam_mg')
        if mom1 > 0:
            # First-moment accumulator: running average of gradients.
            v = tf.Variable(tf.zeros(p.get_shape()), trainable=False,
                            name=p.name.split(':')[0] + '_adam_v')
            v_t = mom1 * v + (1. - mom1) * g
            v_hat = v_t / (1. - tf.pow(mom1, t))  # bias-corrected first moment
            updates.append(v.assign(v_t))
        else:
            v_hat = g  # mom1 == 0 degenerates to an RMSProp-style update
        mg_t = mom2 * mg + (1. - mom2) * tf.square(g)
        mg_hat = mg_t / (1. - tf.pow(mom2, t))  # bias-corrected second moment
        g_t = v_hat / tf.sqrt(mg_hat + 1e-8)  # note: epsilon inside the sqrt
        p_t = p - lr * g_t
        updates.append(mg.assign(mg_t))
        updates.append(p.assign(p_t))
    updates.append(t.assign_add(1))
    return tf.group(*updates)
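For reference, this function implements the Adam update of Kingma & Ba (2015), with mom1 and mom2 playing the roles of \beta_1 and \beta_2. In the paper's notation:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t,   \qquad \hat{m}_t = m_t / (1 - \beta_1^t)
u_t = \beta_2 u_{t-1} + (1 - \beta_2) g_t^2, \qquad \hat{u}_t = u_t / (1 - \beta_2^t)
\theta_t = \theta_{t-1} - \alpha \, \hat{m}_t / \sqrt{\hat{u}_t + \epsilon}

Two quirks of the code relative to the paper: the naming is inverted (v / v_hat hold the first moment m, while mg / mg_hat hold the second moment u), and the epsilon of 1e-8 sits inside the square root rather than outside it (the paper uses \hat{m}_t / (\sqrt{\hat{u}_t} + \epsilon)). Both moment estimates are bias-corrected exactly as above.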
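A minimal usage sketch, assuming TensorFlow 1.x graph mode (under TensorFlow 2, import tensorflow.compat.v1 as tf and call tf.disable_eager_execution() first). The toy least-squares problem and the names x, y, w, cost below are illustrative only, not part of the original snippet:

import numpy as np  # uses `tf` and `adam_updates` defined above

# Toy least-squares problem: fit w so that x @ w approximates y.
x = tf.placeholder(tf.float32, [None, 3])
y = tf.placeholder(tf.float32, [None, 1])
w = tf.Variable(tf.zeros([3, 1]), name='w')
cost = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

# Build the update op before initializing, so the Adam state it creates
# (adam_t, w_adam_mg, w_adam_v) is covered by the initializer.
train_op = adam_updates([w], cost, lr=0.01)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    xs = np.random.randn(256, 3).astype(np.float32)
    ys = xs @ np.array([[1.], [2.], [3.]], dtype=np.float32)
    for _ in range(500):
        sess.run(train_op, {x: xs, y: ys})
    print(sess.run(w))  # should approach [[1.], [2.], [3.]]

Note that adam_updates returns a single grouped op, so one sess.run both applies the parameter updates and advances the shared step counter t.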