def update_hyper_param(self):
    """Write the current momentum and learning rate into every param group.

    For each group in ``self._optimizer.param_groups``:

    * ``group['momentum']`` is set to ``self._mu_t``.
    * If ``self._force_non_inc_step`` is falsy, ``group['lr']`` is set to
      ``self._lr_t * self._lr_factor``.  When ``self._adapt_clip`` is truthy
      and ``lr * sqrt(grad_norm_squared)`` reaches
      ``self._catastrophic_move_thresh``, the learning rate is rescaled to
      ``thresh / sqrt(grad_norm_squared + eps)`` and, if ``self._verbose``
      is set, a warning is logged.
    * Otherwise, once ``self._iter`` exceeds ``self._curv_win_width``,
      ``group['lr']`` is set to the smaller of ``self._lr * self._lr_factor``
      and a cap derived from the ``lr_grad_norm_avg_min`` and
      ``grad_norm_squared_avg_log`` entries of ``self._global_state``.
    * In the remaining case the group's learning rate is left untouched.

    ``eps`` is not defined in this method; it is presumably a small
    module-level constant guarding against division by zero -- confirm
    against the rest of the file.

    Returns:
        None. The param-group dicts are modified in place.
    """
    for group in self._optimizer.param_groups:
        group['momentum'] = self._mu_t

        if not self._force_non_inc_step:
            group['lr'] = self._lr_t * self._lr_factor

            if self._adapt_clip:
                grad_norm_squared = self._global_state['grad_norm_squared']
                projected_move = group['lr'] * np.sqrt(grad_norm_squared)
                # Loose clamp against a catastrophically large move: when
                # lr * ||grad|| reaches the threshold, shrink lr so that the
                # product is (approximately) the threshold itself.
                if projected_move >= self._catastrophic_move_thresh:
                    group['lr'] = (self._catastrophic_move_thresh
                                   / np.sqrt(grad_norm_squared + eps))
                    if self._verbose:
                        logging.warning("clip catastropic move!")

        elif self._iter > self._curv_win_width:
            # Force lr * grad_norm not to increase dramatically. Not needed
            # for basic use; see the comments in YFOptimizer.__init__.
            # NOTE(review): lr_grad_norm_avg() runs once per param group; if
            # it updates running statistics, confirm that repeating it for
            # every group is intended.
            self.lr_grad_norm_avg()
            debias_factor = self.zero_debias_factor()
            global_state = self._global_state
            smoothed_grad_norm = np.sqrt(
                np.exp(global_state['grad_norm_squared_avg_log'] / debias_factor)
            )
            lr_cap = (2.0 * global_state["lr_grad_norm_avg_min"]
                      / (smoothed_grad_norm + eps))
            group['lr'] = min(self._lr * self._lr_factor, lr_cap)
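# (web page residue, not part of the source) Comment list
# (web page residue, not part of the source) Table of contents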