yellowfin.py source file

python
Project: YellowFin_Pytorch  Author: JianGoForIt
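import math          # needed by this excerpt
import numpy as np   # needed by this excerpt

# Assumption: the excerpt uses eps without showing its definition;
# a small constant guarding against division by zero is assumed here.
eps = 1e-15

def dist_to_opt(self):
    """Update the running estimate of the distance to the optimum.

    This is a method of the optimizer class; it reads and writes
    self._global_state and stores the result in self._dist_to_opt.
    """
    global_state = self._global_state
    beta = self._beta
    # Initialize both moving averages on the first iteration.
    if self._iter == 0:
        global_state["grad_norm_avg"] = 0.0
        global_state["dist_to_opt_avg"] = 0.0
    # Exponential moving average of the gradient norm ||g||.
    global_state["grad_norm_avg"] = (
        global_state["grad_norm_avg"] * beta
        + (1 - beta) * math.sqrt(global_state["grad_norm_squared"]))
    # Moving average of avg(||g||) / avg(||g||^2).
    global_state["dist_to_opt_avg"] = (
        global_state["dist_to_opt_avg"] * beta
        + (1 - beta) * global_state["grad_norm_avg"]
        / (global_state["grad_norm_squared_avg"] + eps))
    # Optionally correct the bias introduced by the zero initialization.
    if self._zero_debias:
        debias_factor = self.zero_debias_factor()
        self._dist_to_opt = global_state["dist_to_opt_avg"] / debias_factor
    else:
        self._dist_to_opt = global_state["dist_to_opt_avg"]
    # Optionally rescale for sparse gradients.
    if self._sparsity_debias:
        self._dist_to_opt /= (np.sqrt(self._sparsity_avg) + eps)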
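
The method above depends on optimizer state that the excerpt does not show. To try the same update in isolation, here is a minimal standalone sketch. It is not the project's code, and it rests on several assumptions: the class name `DistToOptEstimator` and its attributes are hypothetical, the moving average of the squared gradient norm is maintained inside the class (the excerpt reads it from `global_state`, where it is presumably updated elsewhere), the debias factor is taken to be the Adam-style `1 - beta ** t`, and the values of `beta` and `eps` are illustrative. The sparsity correction is left out.

```python
import math


class DistToOptEstimator:
    """Hypothetical standalone version of the distance-to-optimum update."""

    def __init__(self, beta=0.9, eps=1e-15, zero_debias=True):
        self.beta = beta                  # smoothing factor (illustrative value)
        self.eps = eps                    # assumed guard against division by zero
        self.zero_debias = zero_debias
        self.step = 0
        self.grad_norm_avg = 0.0          # moving average of ||g||
        self.grad_norm_squared_avg = 0.0  # moving average of ||g||^2
        self.dist_to_opt_avg = 0.0        # moving average of the ratio

    def update(self, grad_norm_squared):
        """Feed one squared gradient norm and return the current estimate."""
        b = self.beta
        self.grad_norm_avg = (
            b * self.grad_norm_avg + (1 - b) * math.sqrt(grad_norm_squared))
        # Assumption: updated here; the excerpt only reads this value.
        self.grad_norm_squared_avg = (
            b * self.grad_norm_squared_avg + (1 - b) * grad_norm_squared)
        self.dist_to_opt_avg = (
            b * self.dist_to_opt_avg
            + (1 - b) * self.grad_norm_avg
            / (self.grad_norm_squared_avg + self.eps))
        self.step += 1
        if self.zero_debias:
            # Assumed Adam-style correction for the zero initialization.
            return self.dist_to_opt_avg / (1.0 - b ** self.step)
        return self.dist_to_opt_avg


if __name__ == "__main__":
    est = DistToOptEstimator()
    for _ in range(5):
        estimate = est.update(4.0)  # constant gradient norm ||g|| = 2
    print(estimate)
```

Under these assumptions, a constant gradient norm c drives the debiased estimate to about 1/c from the first step onward, which makes a quick sanity check for the update.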