def adam_conditional_updates(params, cost, mincost, lr=0.001, mom1=0.9, mom2=0.999):
    """Build Adam-style update pairs that are skipped while ``cost < mincost``.

    For every parameter two shared accumulators are created, both starting
    at zero: a moving average of the gradient and a moving average of the
    squared gradient. Both averages are bias-corrected with a shared step
    counter that starts at 1. Each update expression is wrapped in
    ``ifelse`` so that, when ``cost < mincost``, every shared variable is
    assigned its own current value and therefore does not change.

    Args:
        params: Iterable of shared variables; each must offer ``get_value()``.
        cost: Symbolic expression that is differentiated with respect to
            ``params`` and compared against ``mincost``.
        mincost: Threshold; updates only take effect when ``cost`` is not
            below it.
        lr: Step size applied to the normalized gradient.
        mom1: Decay factor of the gradient average.
        mom2: Decay factor of the squared-gradient average.

    Returns:
        A list of ``(shared_variable, new_expression)`` tuples: three per
        parameter (gradient average, squared-gradient average, parameter),
        followed by a single entry for the step counter.
    """

    def _zero_state(param):
        # Fresh floatX buffer with the same shape as the parameter's value.
        return th.shared(np.cast[th.config.floatX](param.get_value() * 0.))

    skip_step = cost < mincost
    gradients = T.grad(cost, params)
    step = th.shared(np.cast[th.config.floatX](1.))

    # The bias-correction denominators depend only on the step counter,
    # so they are shared by all parameters.
    bias1 = 1. - mom1 ** step
    bias2 = 1. - mom2 ** step

    updates = []
    for param, grad in zip(params, gradients):
        first_moment = _zero_state(param)
        second_moment = _zero_state(param)

        new_first = mom1 * first_moment + (1. - mom1) * grad
        new_second = mom2 * second_moment + (1. - mom2) * T.square(grad)

        # The 1e-8 stabilizer is added inside the square root.
        direction = (new_first / bias1) / T.sqrt(new_second / bias2 + 1e-8)
        new_param = param - lr * direction

        updates.extend([
            (first_moment, ifelse(skip_step, first_moment, new_first)),
            (second_moment, ifelse(skip_step, second_moment, new_second)),
            (param, ifelse(skip_step, param, new_param)),
        ])

    # Kept outside the per-parameter loop so the counter gets exactly one
    # update entry.
    updates.append((step, ifelse(skip_step, step, step + 1)))
    return updates
# [page navigation residue, not code] Comment list
# [page navigation residue, not code] Table of contents