def step(self, closure=None):
    """Performs a single ASGD optimization step.

    For every parameter ``p`` in ``self.param_groups`` that has a gradient,
    this applies an optional weight-decay term to the gradient, shrinks the
    parameter by ``(1 - lambd * eta)``, takes a gradient step of size ``eta``,
    and updates the running average kept in ``state['ax']``. It then
    recomputes ``eta`` and the averaging weight ``mu`` for the next call.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        The value returned by ``closure``, or ``None`` if no closure is given.

    Raises:
        RuntimeError: If any parameter has a sparse gradient.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            # ``.data`` is used throughout so these in-place updates are not
            # recorded by autograd.
            grad = p.grad.data
            if grad.is_sparse:
                raise RuntimeError('ASGD does not support sparse gradients')

            state = self.state[p]
            # Lazily initialise per-parameter state on the first step.
            if not state:
                state['step'] = 0
                state['eta'] = group['lr']
                state['mu'] = 1
                state['ax'] = torch.zeros_like(p.data)
            state['step'] += 1

            if group['weight_decay'] != 0:
                # Out-of-place add, so p.grad itself is left unmodified.
                # Keyword ``alpha`` replaces the deprecated
                # add(Number, Tensor) overload: grad + weight_decay * p.
                grad = grad.add(p.data, alpha=group['weight_decay'])

            # Decay term: p <- p * (1 - lambd * eta).
            p.data.mul_(1 - group['lambd'] * state['eta'])
            # Gradient step: p <- p - eta * grad.
            p.data.add_(grad, alpha=-state['eta'])

            # Averaging: ax <- ax + mu * (p - ax). When mu == 1 this is
            # exactly p, so copy directly.
            if state['mu'] != 1:
                state['ax'].add_(p.data.sub(state['ax']).mul(state['mu']))
            else:
                state['ax'].copy_(p.data)

            # Schedule for the next call:
            #   eta = lr / (1 + lambd * lr * step) ** alpha
            #   mu  = 1 / max(1, step - t0)
            state['eta'] = (group['lr'] /
                            math.pow(1 + group['lambd'] * group['lr'] * state['step'],
                                     group['alpha']))
            state['mu'] = 1 / max(1, state['step'] - group['t0'])

    return loss
评论列表
文章目录