yellowfin.py source code

python

Project: pytorch-planet-amazon    Author: rwightman
import torch


def after_apply(self):
    # Update the running (exponential moving) average of the squared global gradient norm.
    beta = self._beta
    global_state = self._global_state
    if self._iter == 0:
        global_state["grad_norm_squared_avg"] = 0.0

    global_state["grad_norm_squared"] = 0.0
    for group in self._optimizer.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue  # parameter received no gradient this step
            grad = p.grad.data
            # torch.sum(grad * grad) works for tensors of any shape; torch.dot only accepts 1-D tensors.
            global_state['grad_norm_squared'] += torch.sum(grad * grad)

    global_state['grad_norm_squared_avg'] = \
        global_state['grad_norm_squared_avg'] * beta + (1 - beta) * global_state['grad_norm_squared']

    # Refresh the curvature range, gradient variance, and distance-to-optimum estimates.
    self.curvature_range()
    self.grad_variance()
    self.dist_to_opt()
    if self._iter > 0:
        # Compute the per-step targets self._mu_t and self._lr_t, then smooth them with the same beta.
        self.get_mu()
        self.get_lr()
        self._lr = beta * self._lr + (1 - beta) * self._lr_t
        self._mu = beta * self._mu + (1 - beta) * self._mu_t
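after_apply keeps an exponential moving average (EMA) of the squared global gradient norm, summed over every parameter that has a gradient, and at the end smooths the learning rate and momentum with the same beta. The framework-free sketch below isolates just that arithmetic. It assumes gradients are given as flat Python lists (with None for parameters that received no gradient); the helper names `squared_grad_norm` and `ema` and the sample gradients are hypothetical and not part of yellowfin.py.

```python
def squared_grad_norm(grads):
    """Sum of squared entries over every gradient; None entries are skipped."""
    total = 0.0
    for g in grads:
        if g is None:  # parameter received no gradient this step
            continue
        total += sum(x * x for x in g)
    return total


def ema(prev, new, beta):
    """Exponential moving average: beta * prev + (1 - beta) * new."""
    return beta * prev + (1 - beta) * new


beta = 0.999
avg = 0.0  # starts at 0.0 on the first iteration, as in after_apply
# Hypothetical flattened gradients for two optimizer steps.
steps = [
    [[3.0, 4.0], None],
    [[1.0, 0.0], [0.0, 2.0]],
]
for grads in steps:
    avg = ema(avg, squared_grad_norm(grads), beta)
print(avg)
```

The same `ema` update describes the `self._lr` and `self._mu` lines at the end of the method, where the previous smoothed value plays the role of `prev` and `self._lr_t` / `self._mu_t` play the role of `new`. Because the average starts at 0.0 and beta is close to 1, early values are biased toward zero.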
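The bodies of `get_mu` and `get_lr` are not part of this excerpt. As a hedged sketch, assuming they follow the SingleStep tuning rule described in the YellowFin paper (Zhang and Mitliagkas), the targets mu_t and lr_t can be computed in closed form from the curvature range [h_min, h_max], the gradient variance C, and the distance-to-optimum estimate D. The function name `single_step_mu_lr` and its argument names are hypothetical, and the real helpers may differ in details such as epsilon guards against division by zero.

```python
import math


def single_step_mu_lr(h_min, h_max, grad_var, dist_to_opt):
    """Hypothetical closed-form tuner sketched after the YellowFin SingleStep rule."""
    # Minimising x^2 * D^2 + (1 - x)^4 * C / h_min^2 over x = sqrt(mu) and
    # substituting y = x - 1 gives the cubic y^3 + p*y + p = 0 with:
    p = dist_to_opt ** 2 * h_min ** 2 / (2.0 * grad_var)
    # Vieta's substitution y = w - p / (3w) turns the cubic into a quadratic in w^3.
    w3 = (-math.sqrt(p ** 2 + 4.0 / 27.0 * p ** 3) - p) / 2.0
    w = math.copysign(abs(w3) ** (1.0 / 3.0), w3)  # real cube root of a negative value
    y = w - p / (3.0 * w)
    x = y + 1.0
    # Momentum must also be large enough for the condition number h_max / h_min.
    dr = math.sqrt(h_max / h_min)
    mu = max(x ** 2, ((dr - 1.0) / (dr + 1.0)) ** 2)
    # Learning rate from sqrt(mu) = 1 - sqrt(lr * h_min).
    lr = (1.0 - math.sqrt(mu)) ** 2 / h_min
    return mu, lr


# Ill-conditioned case: the condition-number bound dominates mu.
mu, lr = single_step_mu_lr(h_min=1.0, h_max=100.0, grad_var=0.5, dist_to_opt=1.0)
print(mu, lr)
```

With h_max / h_min = 100 the lower bound ((10 - 1) / (10 + 1))^2 exceeds the cubic's solution, so the sketch returns that bound as mu. When h_max equals h_min the bound is 0 and mu comes from the cubic alone.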