def update_hyper_param(self):
    """Push the currently tuned momentum and learning rate into the
    wrapped optimizer's param_groups."""
    for group in self._optimizer.param_groups:
        group['momentum'] = self._mu
        if not self._force_non_inc_step:
            # Clip lr so that lr * ||grad|| stays below the configured threshold.
            # eps is a small positive constant (defined outside this method)
            # guarding against division by zero.
            group['lr'] = min(self._lr * self._lr_factor,
                              self._lr_grad_norm_thresh
                              / (math.sqrt(self._global_state["grad_norm_squared"]) + eps))
        elif self._iter > self._curv_win_width:
            # Clip lr to keep lr * grad_norm from increasing dramatically.
            # Not necessary for basic use; see the comments in
            # YFOptimizer.__init__ for more details.
            self.lr_grad_norm_avg()
            debias_factor = self.zero_debias_factor()
            group['lr'] = min(self._lr * self._lr_factor,
                              2.0 * self._global_state["lr_grad_norm_avg_min"]
                              / (np.sqrt(np.exp(self._global_state['grad_norm_squared_avg_log'] / debias_factor)) + eps))
    return
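Note that this method is only an excerpt: it assumes that math, numpy and a small module-level eps constant are already in scope, and it is normally invoked from the optimizer wrapper's step() once the curvature statistics have been refreshed. The sketch below shows that assumed surrounding context and a typical training-loop call; the constructor arguments, the eps value, and the names model, loss_fn and loader are illustrative assumptions and should be checked against the actual YellowFin source.

# Module-level context assumed by update_hyper_param (sketch, not verbatim source)
import math
import numpy as np

eps = 1e-6  # small guard against division by zero; exact value is an assumption

# Typical usage: YFOptimizer wraps an inner SGD optimizer and retunes
# lr / momentum on every step, which is where update_hyper_param is called.
# model, loss_fn and loader are the user's own objects.
opt = YFOptimizer(model.parameters(), lr=1.0, mu=0.0)
for inputs, targets in loader:
    opt.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    opt.step()  # refreshes curvature statistics, then calls update_hyper_param()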