update_function.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:recnet 作者: joergfranke 项目源码 文件源码
def fit(self, weights, o_error, tpo ):

        gradients = T.grad(o_error ,weights)
        updates = []
        for c, v, w, g in zip(self.t_cache, self.t_velocity, weights,gradients):
            new_velocity = T.sub( T.mul(tpo["momentum_rate"], v) , T.mul(tpo["learn_rate"], g) )
            new_cache = T.add( T.mul(tpo["decay_rate"] , c) , T.mul(T.sub( 1, tpo["decay_rate"]) , T.sqr(g)))
            new_weights = T.sub(T.add(w , new_velocity) , T.true_div( T.mul(g,tpo["learn_rate"]) , T.sqrt(T.add(new_cache,0.1**8))))
            updates.append((w, new_weights))
            updates.append((v, new_velocity))
            updates.append((c, new_cache))

        return updates


######                 Nesterov momentum
########################################
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号