def after_apply(self):
    # compute running average of gradient and norm of gradient
    beta = self._beta
    global_state = self._global_state

    if self._iter == 0:
        global_state["grad_norm_squared_avg"] = 0.0

    # accumulate the squared gradient norm over all parameters
    global_state["grad_norm_squared"] = 0.0
    for group in self._optimizer.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad.data
            # global_state['grad_norm_squared'] += torch.dot(grad, grad)
            global_state['grad_norm_squared'] += torch.sum(grad * grad)

    # exponential moving average of the squared gradient norm
    global_state['grad_norm_squared_avg'] = \
        global_state['grad_norm_squared_avg'] * beta + (1 - beta) * global_state['grad_norm_squared']
    # global_state['grad_norm_squared_avg'].mul_(beta).add_(1 - beta, global_state['grad_norm_squared'])

    # update the statistics the tuner needs: curvature range,
    # gradient variance, and distance to the optimum
    self.curvature_range()
    self.grad_variance()
    self.dist_to_opt()

    if self._iter > 0:
        # solve for the new momentum and learning rate, then smooth them
        self.get_mu()
        self.get_lr()
        self._lr = beta * self._lr + (1 - beta) * self._lr_t
        self._mu = beta * self._mu + (1 - beta) * self._mu_t
    return
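
For intuition, the smoothing applied to `_lr` and `_mu` (and to `grad_norm_squared_avg`) is a plain exponential moving average with decay `beta`. Below is a minimal standalone sketch of that update rule; the values of `beta`, `lr_t`, and `mu_t` are made up for illustration, whereas in the real code they come from `self._beta`, `get_lr()`, and `get_mu()`.

```python
# Sketch of the exponential moving average used for self._lr / self._mu.
# beta, lr_t, mu_t are illustrative values, not taken from the tuner itself.
beta = 0.999
lr, mu = 0.1, 0.0          # current smoothed hyperparameters
lr_t, mu_t = 0.05, 0.9     # hypothetical per-step targets (get_lr / get_mu outputs)

for step in range(5):
    # same update rule as the last two assignments in after_apply
    lr = beta * lr + (1 - beta) * lr_t
    mu = beta * mu + (1 - beta) * mu_t
    print(f"step {step}: lr={lr:.6f}, mu={mu:.6f}")

# lr drifts slowly from 0.1 toward 0.05 and mu from 0.0 toward 0.9;
# with beta close to 1 the tuner reacts only gradually to noisy per-step estimates.
```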