def Adam(grads, lr=0.0002, b1=0.1, b2=0.001, e=1e-8):
    """Build Adam update rules for a mapping of parameters to gradients.

    Note the parametrization: ``b1`` and ``b2`` are the weights given to the
    *new* gradient (and squared gradient) in the running averages, i.e. they
    play the role of ``1 - beta1`` and ``1 - beta2`` in the more common
    notation.

    Args:
        grads: Mapping from parameter to its gradient expression. Each
            parameter must expose ``get_value()`` and a ``name`` attribute
            (presumably shared variables -- confirm against callers).
        lr: Base learning rate.
        b1: Weight of the current gradient in the first-moment average.
        b2: Weight of the current squared gradient in the second-moment
            average.
        e: Small constant added to the denominator to avoid division by zero.

    Returns:
        A tuple ``(updates, varlist)``. ``updates`` is a list of
        ``(variable, new_value_expression)`` pairs covering every moment
        accumulator, every parameter and, last, the step counter.
        ``varlist`` holds the per-parameter moment accumulators (``m`` then
        ``v`` for each parameter); the step counter is not included in it.
    """
    updates = []
    varlist = []

    # Step counter shared by all parameters; advanced once per application
    # of the returned updates.
    i = sharedX(0.)
    i_t = i + 1.

    # Bias-correction terms for the zero-initialised moment estimates,
    # folded into a single effective learning rate.
    fix1 = 1. - (1. - b1) ** i_t
    fix2 = 1. - (1. - b2) ** i_t
    lr_t = lr * (T.sqrt(fix2) / fix1)

    for p, g in grads.items():
        # A parameter created without a name has ``name`` set to None, which
        # cannot be concatenated with a str; fall back to a placeholder so
        # unnamed parameters no longer raise TypeError.
        base_name = p.name if p.name is not None else 'unnamed'

        # Two separate zero arrays on purpose: the accumulators must never
        # share storage.
        m = sharedX(p.get_value() * 0., name=base_name + '_adam_optimizer_m')
        v = sharedX(p.get_value() * 0., name=base_name + '_adam_optimizer_v')

        m_t = (b1 * g) + ((1. - b1) * m)
        v_t = (b2 * T.sqr(g)) + ((1. - b2) * v)
        g_t = m_t / (T.sqrt(v_t) + e)
        p_t = p - (lr_t * g_t)

        updates.append((m, m_t))
        updates.append((v, v_t))
        updates.append((p, p_t))
        varlist.append(m)
        varlist.append(v)

    updates.append((i, i_t))
    return updates, varlist
评论列表
文章目录