def get_lr(self):
    """Recompute the learning rate and store it in ``self._lr_t``.

    The base value is ``(1 - sqrt(self._mu_t))**2 / (self._h_min + eps)``.
    It is then capped by that same value scaled with
    ``(self._iter + 1) / (10 * self._curv_win_width)``, so the stored rate
    can never exceed the base value and is reduced while the scale factor
    is below 1.

    Nothing is returned; callers read ``self._lr_t`` afterwards.

    NOTE(review): ``eps`` is resolved from the enclosing module and is not
    visible in this chunk -- presumably a small positive constant that
    guards against division by zero; confirm at its definition.
    """
    momentum_gap = 1.0 - math.sqrt(self._mu_t)
    self._lr_t = momentum_gap ** 2 / (self._h_min + eps)

    # Slow start: during the first iterations, ramp the rate up gradually
    # so that it cannot become huge before enough steps have completed.
    base_lr = self._lr_t
    ramped_lr = base_lr * (self._iter + 1) / float(10.0 * self._curv_win_width)
    self._lr_t = min(base_lr, ramped_lr)
    return None
评论列表
文章目录