def step(self):
    super(Adam, self).step()
    self.t += 1
    # Lazily create the first/second moment buffers on the first call
    if len(self.m) == 0:
        for p in self.params:
            self.m[p] = torch.zeros(p.size())
            self.v[p] = torch.zeros(p.size())
    for p in self.params:
        # Exponential moving averages of the gradient and the squared gradient
        mt = self.beta1 * self.m[p] + (1 - self.beta1) * p.grad.data
        vt = self.beta2 * self.v[p] + (1 - self.beta2) * p.grad.data ** 2
        # Bias-corrected estimates (corrects the zero initialization of m and v)
        m = mt / (1 - self.beta1 ** self.t)
        v = vt / (1 - self.beta2 ** self.t)
        # Update: p <- p - lr * m_hat / (sqrt(v_hat) + eps)
        rate = self.lr / (torch.sqrt(v) + self.epsilon)
        p.data.sub_(rate * m)
        self.m[p] = mt
        self.v[p] = vt
    self.clear_gradients()
# Alias