def update(self, params, grads):
    """Perform one Adadelta step, modifying ``params`` in place.

    For every parameter the per-parameter state is advanced as::

        cache  = rho * cache + (1 - rho) * grad**2
        step   = grad * sqrt(delta + epsilon) / sqrt(cache + epsilon)
        param -= lr * step
        delta  = rho * delta + (1 - rho) * step**2

    Args:
        params: Sequence of parameter arrays. Each one is updated with an
            augmented ``-=``. This presumably expects mutable arrays such as
            NumPy ndarrays (TODO confirm). An immutable scalar would only
            rebind the loop variable and would not be updated.
        grads: Gradients paired with ``params`` by position.

    State lives in ``self.cache`` (running average of squared gradients) and
    ``self.delta`` (running average of squared steps). Each list is allocated
    on first use with ``_zero(p.shape)``, which is presumably a zero-filled
    array of that shape (defined elsewhere; verify).

    Iteration follows ``zip`` semantics: it stops at the shortest of cache,
    delta, params and grads, so surplus entries are silently left untouched.

    Returns:
        None.
    """
    # Lazily allocate the running averages the first time we see params.
    if self.cache is None:
        self.cache = [_zero(p.shape) for p in params]
    if self.delta is None:
        self.delta = [_zero(p.shape) for p in params]

    rho, eps, lr = self.rho, self.epsilon, self.lr
    per_param = zip(self.cache, self.delta, params, grads)
    for idx, (sq_grad_avg, sq_step_avg, param, grad) in enumerate(per_param):
        sq_grad_avg = rho * sq_grad_avg + (1 - rho) * np.power(grad, 2)
        # The step uses the *previous* squared-step average (read before it
        # is refreshed below), which is what makes the update unit-consistent.
        step = grad * np.sqrt(sq_step_avg + eps) / np.sqrt(sq_grad_avg + eps)
        param -= lr * step  # in-place: mutates the caller's array
        sq_step_avg = rho * sq_step_avg + (1 - rho) * np.power(step, 2)
        self.cache[idx] = sq_grad_avg
        self.delta[idx] = sq_step_avg
评论列表
文章目录