yellowfin.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:YellowFin 作者: JianGoForIt 项目源码 文件源码
def curvature_range(self):
    """Build the ops that track a smoothed min/max of the curvature window.

    The value ``self._grad_norm_squared + EPS`` is written into a
    fixed-length window at index ``self._global_step % self._curv_win_width``.
    The min and max of the first ``min(width, global_step + 1)`` entries are
    taken (in the log domain when the matching ``_h_*_log_smooth`` flag is
    set), passed through ``self._moving_averager``, and the averaged results
    are stored on ``self._h_min`` / ``self._h_max``.  When
    ``self._sparsity_debias`` is set, both are multiplied by
    ``self._sparsity_avg``.

    Attributes assigned here: ``_curv_win``, ``_h_min_t``, ``_h_max_t``,
    ``_h_min`` and ``_h_max``.

    Returns:
        A one-element list holding the op returned by
        ``self._moving_averager.apply``.
    """
    # Fixed-length window, zero-initialised and excluded from training.
    self._curv_win = tf.Variable(
        np.zeros([self._curv_win_width]),
        dtype=tf.float32,
        name="curv_win",
        trainable=False)

    # Store the raw value in the window.  Storing its log instead was
    # considered (it would let the range follow the trend faster) but is
    # not what is done here.
    write_index = self._global_step % self._curv_win_width
    sample = self._grad_norm_squared + EPS
    self._curv_win = tf.scatter_update(self._curv_win, write_index, sample)

    # Iterations are counted from 0, hence the "+ 1" when working out how
    # many leading entries of the window to read.
    # NOTE(review): `dim` looks like the legacy spelling of `axis` for
    # tf.expand_dims -- confirm against the targeted TF version.
    slice_begin = tf.constant([0])
    window_cap = tf.constant(self._curv_win_width)
    filled = tf.minimum(window_cap, self._global_step + 1)
    valid_window = tf.slice(
        self._curv_win, slice_begin, tf.expand_dims(filled, dim=0))

    def raw_extreme(reduce_fn, in_log_domain):
        # Reduce the valid window; optionally move the result to log space.
        extreme = reduce_fn(valid_window)
        return tf.log(extreme + EPS) if in_log_domain else extreme

    def averaged(raw, in_log_domain):
        # Read back the moving average of `raw`, undoing the log if needed.
        smoothed = tf.identity(self._moving_averager.average(raw))
        return tf.exp(smoothed) if in_log_domain else smoothed

    self._h_min_t = raw_extreme(tf.reduce_min, self._h_min_log_smooth)
    self._h_max_t = raw_extreme(tf.reduce_max, self._h_max_log_smooth)
    raw_min, raw_max = self._h_min_t, self._h_max_t

    with tf.control_dependencies([raw_min, raw_max]):
        avg_op = self._moving_averager.apply([raw_min, raw_max])
        # The averaged values are read only after the averager has run.
        with tf.control_dependencies([avg_op]):
            self._h_min = averaged(raw_min, self._h_min_log_smooth)
            self._h_max = averaged(raw_max, self._h_max_log_smooth)
        if self._sparsity_debias:
            debias = self._sparsity_avg
            self._h_min = self._h_min * debias
            self._h_max = self._h_max * debias

    return [avg_op]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号