def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0,
             initial_accumulator_value=0):
    """Build the Adagrad optimizer and seed per-parameter state.

    Args:
        params: parameters (or parameter-group dicts) to optimize; passed
            unchanged to the parent class's ``__init__``.
        lr: learning rate (default 1e-2). Must be non-negative.
        lr_decay: learning-rate decay (default 0). Must be non-negative.
            Only stored in the group defaults here; how it is applied lives
            in the update step, which is not part of this method.
        weight_decay: weight-decay coefficient (default 0). Must be
            non-negative. Stored in the group defaults like ``lr_decay``.
        initial_accumulator_value: value every element of each parameter's
            ``'sum'`` state tensor starts at (default 0, which matches the
            previous all-zeros initialisation). Must be non-negative.

    Raises:
        ValueError: if any of the hyperparameters above is negative or NaN.
    """
    # Validate before touching the parent constructor so no param groups or
    # state are built from bad hyperparameters. ``not 0.0 <= x`` (rather
    # than ``x < 0``) also rejects NaN, whose comparisons are always False.
    if not 0.0 <= lr:
        raise ValueError("Invalid learning rate: {}".format(lr))
    if not 0.0 <= lr_decay:
        raise ValueError("Invalid lr_decay value: {}".format(lr_decay))
    if not 0.0 <= weight_decay:
        raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
    if not 0.0 <= initial_accumulator_value:
        raise ValueError("Invalid initial_accumulator_value value: {}".format(
            initial_accumulator_value))

    defaults = dict(lr=lr, lr_decay=lr_decay, weight_decay=weight_decay,
                    initial_accumulator_value=initial_accumulator_value)
    super(Adagrad, self).__init__(params, defaults)

    # Eagerly create state for every parameter, so it exists before the
    # first update rather than being created lazily.
    for group in self.param_groups:
        for p in group['params']:
            state = self.state[p]
            state['step'] = 0
            # Per-element accumulator with the same shape/dtype/device as the
            # parameter; presumably the running sum of squared gradients read
            # by step() (TODO confirm against step()).
            state['sum'] = torch.zeros_like(p.data).fill_(
                initial_accumulator_value)
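# Stray web-page labels left over from the page this code was copied from:
# "评论列表" ("Comment list") and "文章目录" ("Table of contents").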