yellowfin_backup.py source code

python

Project: YellowFin_Pytorch  Author: JianGoForIt
import numpy as np

# Small constant for numerical stability. The snippet relies on a module-level
# `eps` that is not shown here; this value is an assumption.
eps = 1e-6


def lr_grad_norm_avg(self):
    # Track a smoothed lr * ||grad|| so that it does not increase
    # dramatically when training becomes unstable.
    # Not necessary for basic use.
    global_state = self._global_state
    beta = self._beta
    # Exponential moving average of log(||grad||^2).
    if 'grad_norm_squared_avg_log' not in global_state:
      global_state['grad_norm_squared_avg_log'] = 0.0
    global_state['grad_norm_squared_avg_log'] = \
      global_state['grad_norm_squared_avg_log'] * beta \
      + (1 - beta) * np.log(global_state['grad_norm_squared'] + eps)
    # Current sample of log(lr * ||grad||).
    log_lr_grad_norm = np.log(
      self._lr * np.sqrt(global_state['grad_norm_squared']) + eps)
    if "lr_grad_norm_avg" not in global_state:
      # First call: the moving average starts from 0.0.
      global_state["lr_grad_norm_avg"] = (1 - beta) * log_lr_grad_norm
      # we monitor the minimal smoothed ||lr * grad||
      global_state["lr_grad_norm_avg_min"] = \
        np.exp(global_state["lr_grad_norm_avg"] / self.zero_debias_factor())
    else:
      global_state["lr_grad_norm_avg"] = \
        global_state["lr_grad_norm_avg"] * beta + (1 - beta) * log_lr_grad_norm
      global_state["lr_grad_norm_avg_min"] = \
        min(global_state["lr_grad_norm_avg_min"],
            np.exp(global_state["lr_grad_norm_avg"] / self.zero_debias_factor()))
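The bookkeeping in the method above can be tried outside the optimizer. The following is a minimal sketch, not the project's code: the class name `LrGradNormTracker`, the constant `EPS`, and the debias form `1 - beta**t` are all assumptions made for illustration, since `zero_debias_factor()` and `eps` are not shown in the snippet. It keeps an exponential moving average of log(lr * ||grad||), debiases it, and records the smallest smoothed value seen so far.

```python
import math

EPS = 1e-6  # assumed value; the snippet's module-level `eps` is not shown


class LrGradNormTracker:
    """Hypothetical stand-alone tracker mirroring the snippet's bookkeeping."""

    def __init__(self, beta=0.999):
        self.beta = beta
        self.step = 0
        self.log_avg = 0.0        # moving average of log(lr * ||grad||)
        self.min_smoothed = None  # smallest debiased value seen so far

    def update(self, lr, grad_norm_squared):
        self.step += 1
        sample = math.log(lr * math.sqrt(grad_norm_squared) + EPS)
        self.log_avg = self.beta * self.log_avg + (1 - self.beta) * sample
        # Assumed zero-debias form: divide the average by 1 - beta**t.
        debias = 1.0 - self.beta ** self.step
        smoothed = math.exp(self.log_avg / debias)
        if self.min_smoothed is None or smoothed < self.min_smoothed:
            self.min_smoothed = smoothed
        return smoothed


tracker = LrGradNormTracker(beta=0.9)
for _ in range(5):
    tracker.update(lr=0.1, grad_norm_squared=4.0)
# A sudden jump in the gradient norm raises the smoothed value,
# while the recorded minimum stays at its earlier level.
spike = tracker.update(lr=0.1, grad_norm_squared=400.0)
```

Averaging in log space makes the estimate a geometric mean, so a single large gradient moves it less than it would move an arithmetic average.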