yellowfin.py source file

python

Project: YellowFin_Pytorch  Author: JianGoForIt
def update_hyper_param(self):
  # Relies on module-level names from yellowfin.py: np (numpy), logging,
  # and eps, a small constant used to avoid division by zero.
  for group in self._optimizer.param_groups:
    group['momentum'] = self._mu_t
    # group['momentum'] = max(self._mu, self._mu_t)
    if not self._force_non_inc_step:
      group['lr'] = self._lr_t * self._lr_factor
      # A loose clamp to prevent a catastrophically large move. If
      # lr * grad_norm reaches the threshold, cap lr at
      # threshold / grad_norm so the step stays bounded.
      if self._adapt_clip and (
          group['lr'] * np.sqrt(self._global_state['grad_norm_squared'])
          >= self._catastrophic_move_thresh):
        group['lr'] = self._catastrophic_move_thresh / np.sqrt(
            self._global_state['grad_norm_squared'] + eps)
        if self._verbose:
          logging.warning("clip catastrophic move!")
    elif self._iter > self._curv_win_width:
      # Keep lr * grad_norm from increasing dramatically.
      # Not necessary for basic use. Please refer to the comments
      # in YFOptimizer.__init__ for more details.
      self.lr_grad_norm_avg()
      debias_factor = self.zero_debias_factor()
      grad_norm_avg = np.sqrt(
          np.exp(self._global_state['grad_norm_squared_avg_log'] / debias_factor))
      group['lr'] = min(
          self._lr * self._lr_factor,
          2.0 * self._global_state["lr_grad_norm_avg_min"] / (grad_norm_avg + eps))
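
The two learning-rate caps in update_hyper_param can be tried out in isolation. The following is a minimal standalone sketch, not code from the project: the function names `clamp_lr` and `non_increasing_lr` are hypothetical, the value of `EPS` is an assumption (the snippet only shows that a small `eps` exists), and reading `grad_norm_squared_avg_log` as a running average of the log of the squared gradient norm is inferred from its name.

```python
import math

# Assumed value; the snippet above only shows that a small `eps` exists.
EPS = 1e-6


def clamp_lr(lr, grad_norm_squared, move_thresh, eps=EPS):
    """Cap lr so that lr * ||grad|| stays at or below move_thresh.

    Mirrors the `_adapt_clip` branch: lr is left unchanged unless
    lr * ||grad|| reaches the threshold.
    """
    grad_norm = math.sqrt(grad_norm_squared)
    if lr * grad_norm >= move_thresh:
        return move_thresh / math.sqrt(grad_norm_squared + eps)
    return lr


def non_increasing_lr(base_lr, lr_grad_norm_avg_min, grad_norm_squared_avg_log,
                      debias_factor, eps=EPS):
    """Cap base_lr at 2 * min(avg lr * ||grad||) / averaged ||grad||.

    Mirrors the `_force_non_inc_step` branch. The averaged gradient norm is
    recovered from a debiased average of log(||grad||^2), which is an
    interpretation of the variable names rather than a documented fact.
    """
    grad_norm_avg = math.sqrt(math.exp(grad_norm_squared_avg_log / debias_factor))
    return min(base_lr, 2.0 * lr_grad_norm_avg_min / (grad_norm_avg + eps))


if __name__ == "__main__":
    # Small gradient: the move is below the threshold, lr is untouched.
    print(clamp_lr(0.1, 4.0, 1.0))
    # Large gradient: lr is cut so that lr * ||grad|| is about the threshold.
    print(clamp_lr(1.0, 100.0, 1.0))
    # Averaged ||grad|| is 2, so the cap is about 2 * 0.5 / 2 = 0.5.
    print(non_increasing_lr(1.0, 0.5, math.log(4.0), 1.0))
```

In the method above, `base_lr` corresponds to `self._lr * self._lr_factor`, and both caps are applied per parameter group.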