tf-keras-skeleton.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:LIE 作者: EmbraceLife 项目源码 文件源码
def get_updates(self, params, constraints, loss):
        """Build the Nadam-style (Nesterov-momentum Adam) update operations.

        One first-moment and one second-moment accumulator is created per
        parameter. ``self.updates`` is then filled, in this order, with:

        1. the increment of ``self.iterations``;
        2. the ``(self.m_schedule, new_schedule)`` update pair;
        3. for each parameter, the updates of its first moment, its second
           moment, and finally its value.

        Side effects: overwrites ``self.updates`` and sets ``self.weights``
        to ``[self.iterations] + first_moments + second_moments``.

        NOTE(review): ``self.m_schedule`` is updated here but is not included
        in ``self.weights``; confirm this is intentional.
        NOTE(review): assumes ``self.iterations`` is a float-typed variable,
        since ``step`` is fed into ``K.pow`` with float operands. TODO confirm
        against the class constructor.

        Args:
            params: The parameters (variables) to optimize.
            constraints: Mapping from a parameter to a callable. When a
                parameter is present, the callable is applied to its proposed
                new value before that value is assigned.
            loss: The loss whose gradients w.r.t. ``params`` drive the update.

        Returns:
            ``self.updates``, the list of update operations built above.
        """
        gradients = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        step = self.iterations + 1

        def _warm_momentum(at_step):
            # Warming momentum schedule recommended in [2]:
            # beta_1 * (1 - 0.5 * 0.96 ** (at_step * schedule_decay)).
            decay_term = K.pow(0.96, at_step * self.schedule_decay)
            return self.beta_1 * (1. - 0.5 * (decay_term))

        mu_now = _warm_momentum(step)
        mu_next = _warm_momentum(step + 1)

        # Running product of momentum coefficients: the value for this step,
        # and a look-ahead value that also folds in next step's coefficient.
        schedule_now = self.m_schedule * mu_now
        schedule_next = self.m_schedule * mu_now * mu_next
        self.updates.append((self.m_schedule, schedule_now))

        param_shapes = [K.int_shape(param) for param in params]
        first_moments = [K.zeros(shape) for shape in param_shapes]
        second_moments = [K.zeros(shape) for shape in param_shapes]

        self.weights = [self.iterations] + first_moments + second_moments

        per_param = zip(params, gradients, first_moments, second_moments)
        for param, grad, m_prev, v_prev in per_param:
            # Update equations as given in [1].
            grad_hat = grad / (1. - schedule_now)
            m_new = self.beta_1 * m_prev + (1. - self.beta_1) * grad
            m_hat = m_new / (1. - schedule_next)
            v_new = self.beta_2 * v_prev + (1. - self.beta_2) * K.square(grad)
            v_hat = v_new / (1. - K.pow(self.beta_2, step))
            # Blend the corrected current gradient with the look-ahead
            # corrected first moment (the Nesterov-style step).
            m_bar = (1. - mu_now) * grad_hat + mu_next * m_hat

            self.updates.append(K.update(m_prev, m_new))
            self.updates.append(K.update(v_prev, v_new))

            candidate = param - self.lr * m_bar / (K.sqrt(v_hat) + self.epsilon)

            # Project the proposed value through the parameter's constraint,
            # if one was registered for it.
            if param in constraints:
                candidate = constraints[param](candidate)
            self.updates.append(K.update(param, candidate))
        return self.updates
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号