yellowfin_backup.py source code

python

Project: YellowFin_Pytorch  Author: JianGoForIt
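def update_hyper_param(self):
    # Push the tuned momentum (mu) and learning rate into every param group
    # of the wrapped torch optimizer. Relies on a module-level `import math`
    # and a small constant `eps`, neither shown in this excerpt.
    for group in self._optimizer.param_groups:
      group['momentum'] = self._mu
      if not self._force_non_inc_step:
        # Cap lr so that lr * ||grad|| stays below _lr_grad_norm_thresh.
        group['lr'] = min(self._lr * self._lr_factor,
          self._lr_grad_norm_thresh / (math.sqrt(self._global_state["grad_norm_squared"]) + eps))
      elif self._iter > self._curv_win_width:
        # Force lr * grad_norm not to increase dramatically.
        # Not necessary for basic use; see the comments in
        # YFOptimizer.__init__ for details. In this mode lr is left
        # unchanged until the curvature window has been filled.
        self.lr_grad_norm_avg()
        debias_factor = self.zero_debias_factor()
        # sqrt(exp(debiased average of log ||g||^2)): a geometric-mean grad norm.
        grad_norm_geo = math.sqrt(math.exp(
          self._global_state['grad_norm_squared_avg_log'] / debias_factor))
        group['lr'] = min(self._lr * self._lr_factor,
          2.0 * self._global_state["lr_grad_norm_avg_min"] / (grad_norm_geo + eps))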
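To make the two learning-rate caps in update_hyper_param concrete, here is a minimal stand-alone sketch of the same arithmetic, with no optimizer or torch dependency. The helper names (clipped_lr, debiased_log_ema, non_inc_lr), the EPS value, and the Adam-style debias factor 1 - beta**t are assumptions made for illustration; they are not taken from YellowFin's source.

```python
import math

# Assumption: a small stabilizing constant; the real `eps` value is not shown.
EPS = 1e-6


def clipped_lr(base_lr, lr_factor, lr_grad_norm_thresh, grad_norm_squared, eps=EPS):
    """Default branch: cap lr so that lr * ||g|| stays below the threshold."""
    grad_norm = math.sqrt(grad_norm_squared)
    return min(base_lr * lr_factor, lr_grad_norm_thresh / (grad_norm + eps))


def debiased_log_ema(values, beta):
    """EMA of log(values) starting from 0, plus its zero-debias factor.

    Assumption: debias factor is 1 - beta**t, as in Adam-style bias correction.
    avg_log / debias is then an exponentially weighted mean of log(values),
    so exp() of it is an exponentially weighted geometric mean.
    """
    avg_log = 0.0
    for v in values:
        avg_log = beta * avg_log + (1.0 - beta) * math.log(v)
    debias = 1.0 - beta ** len(values)
    return avg_log, debias


def non_inc_lr(base_lr, lr_factor, lr_grad_norm_avg_min,
               grad_norm_squared_avg_log, debias_factor, eps=EPS):
    """Non-increasing branch: cap lr using the geometric-mean grad norm."""
    grad_norm_geo = math.sqrt(math.exp(grad_norm_squared_avg_log / debias_factor))
    return min(base_lr * lr_factor, 2.0 * lr_grad_norm_avg_min / (grad_norm_geo + eps))


if __name__ == "__main__":
    # ||g||^2 = 400, so ||g|| = 20 and the cap is about 1 / 20 = 0.05.
    print(clipped_lr(0.1, 1.0, 1.0, 400.0))
    # A constant sequence of squared norms (4.0) is recovered exactly after debiasing.
    avg_log, debias = debiased_log_ema([4.0, 4.0, 4.0], beta=0.9)
    print(math.exp(avg_log / debias))
    print(non_inc_lr(0.1, 1.0, 0.05, avg_log, debias))
```

The debiasing matters early in training: without dividing by 1 - beta**t, an EMA initialized at 0 underestimates the log-norm and the cap would be computed from a biased average.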