optimizers.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:Optimization 作者: tdozat 项目源码 文件源码
def _prepare(self, grads_and_vars):
    """"""

    if self._lr is None:
      sTy = 0
      sTs = 0
      yTy = 0
      for g_t, x_tm1 in grads_and_vars:
        if g_t is None:
          continue
        with ops.name_scope('update_' + x_tm1.op.name), ops.device(x_tm1.device):
          if isinstance(g_t, ops.Tensor):
            g_tm1 = self.get_slot(x_tm1, 'g')
            s_tm1 = self.get_slot(x_tm1, 's')
            y_t = (g_t-g_tm1)
            sTy += math_ops.reduce_sum(s_tm1*y_t)
            sTs += math_ops.reduce_sum(s_tm1**2)
            yTy += math_ops.reduce_sum(y_t**2)
          else:
            idxs, idxs_ = array_ops.unique(g_t.indices)
            g_t_ = math_ops.unsorted_segment_sum(g_t.values, idxs_, array_ops.size(idxs))
            g_tm1 = self.get_slot(x_tm1, 'g')
            g_tm1_ = array_ops.gather(g_tm1, idxs)
            s_tm1 = self.get_slot(x_tm1, 's')
            s_tm1_ = array_ops.gather(s_tm1, idxs)
            y_t_ = (g_t_-g_tm1_)
            sTy += math_ops.reduce_sum(s_tm1_*y_t_)
            sTs += math_ops.reduce_sum(s_tm1_**2)
            yTy += math_ops.reduce_sum(y_t_**2)
      sTy = math_ops.abs(sTy)
      self._lr = sTs / (sTy + self._eps)

  #=============================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号