nn.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:weightnorm 作者: openai 项目源码 文件源码
def adamax_updates(params, cost, lr=0.001, mom1=0.9, mom2=0.999):
    updates = []
    grads = T.grad(cost, params)
    for p, g in zip(params, grads):
        mg = th.shared(np.cast[th.config.floatX](p.get_value() * 0.))
        v = th.shared(np.cast[th.config.floatX](p.get_value() * 0.))
        if mom1>0:
            v_t = mom1*v + (1. - mom1)*g
            updates.append((v,v_t))
        else:
            v_t = g
        mg_t = T.maximum(mom2*mg, abs(g))
        g_t = v_t / (mg_t + 1e-6)
        p_t = p - lr * g_t
        updates.append((mg, mg_t))
        updates.append((p, p_t))
    return updates
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号