adaptive.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:seq2graph 作者: masterkeywikz 项目源码 文件源码
def _get_updates_for(self, param, grad):
        grad_tm1 = shared_like(param, 'grad')
        step_tm1 = shared_like(param, 'step', self.learning_rate.eval())
        test = grad * grad_tm1
        diff = TT.lt(test, 0)
        steps = step_tm1 * (TT.eq(test, 0) +
                            TT.gt(test, 0) * self.step_increase +
                            diff * self.step_decrease)
        step = TT.minimum(self.max_step, TT.maximum(self.min_step, steps))
        grad = grad - diff * grad
        yield param, param - TT.sgn(grad) * step
        yield grad_tm1, grad
        yield step_tm1, step
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号