def step(self):
    # Add weight decay
    if self.weight_decay > 0:
        for p in self.model.parameters():
            p.grad.data.add_(self.weight_decay, p.data)
    updates = {}
    for i, m in enumerate(self.modules):
        assert len(list(m.parameters())) == 1, "Can handle only one parameter at the moment"
        classname = m.__class__.__name__
        p = next(m.parameters())
        la = self.damping + self.weight_decay

        if self.steps % self.Tf == 0:
            # My asynchronous implementation exists, I will add it later.
            # Experimenting with different ways to do this in PyTorch.
            # Recompute the eigendecompositions of the two Kronecker
            # factors every Tf steps.
            self.d_a[m], self.Q_a[m] = torch.symeig(
                self.m_aa[m], eigenvectors=True)
            self.d_g[m], self.Q_g[m] = torch.symeig(
                self.m_gg[m], eigenvectors=True)
            # Zero out numerically negligible eigenvalues.
            self.d_a[m].mul_((self.d_a[m] > 1e-6).float())
            self.d_g[m].mul_((self.d_g[m] > 1e-6).float())

        if classname == 'Conv2d':
            # Flatten the convolution kernel gradient into a 2D matrix.
            p_grad_mat = p.grad.data.view(p.grad.data.size(0), -1)
        else:
            p_grad_mat = p.grad.data
        # Apply the damped Kronecker-factored inverse to the gradient by
        # working in the eigenbases of the two factors.
        v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m]
        v2 = v1 / (
            self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) + la)
        v = self.Q_g[m] @ v2 @ self.Q_a[m].t()
        v = v.view(p.grad.data.size())
        updates[p] = v

    # KL clipping: rescale the preconditioned gradient so that
    # lr^2 * sum(v * grad) stays within kl_clip.
    vg_sum = 0
    for p in self.model.parameters():
        v = updates[p]
        vg_sum += (v * p.grad.data * self.lr * self.lr).sum()
    nu = min(1, math.sqrt(self.kl_clip / vg_sum))

    # Write the scaled update back into the gradients and let the
    # wrapped optimizer take the actual step.
    for p in self.model.parameters():
        v = updates[p]
        p.grad.data.copy_(v)
        p.grad.data.mul_(nu)

    self.optim.step()
    self.steps += 1
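
The v1 / v2 / v computation above applies the damped Kronecker-factored inverse without ever forming the full curvature matrix: it solves G @ V @ A + la * V = grad in the shared eigenbasis of the two factors. Below is a quick, self-contained sanity check of that identity. It is my own toy example, not part of the optimizer class; the sizes and names (G_f, A_f) are made up, and torch.linalg.eigh stands in for the older torch.symeig. It compares the eigenbasis solve against a direct solve with the explicit Kronecker product:

import torch

torch.manual_seed(0)
m, n = 3, 4
# Random SPD matrices standing in for the Kronecker factors m_gg and m_aa.
G_f = torch.randn(m, m)
G_f = G_f @ G_f.t() + torch.eye(m)
A_f = torch.randn(n, n)
A_f = A_f @ A_f.t() + torch.eye(n)
la = 0.1
grad = torch.randn(m, n)

# Eigenbasis solve, mirroring v1 / v2 / v in step().
d_g, Q_g = torch.linalg.eigh(G_f)
d_a, Q_a = torch.linalg.eigh(A_f)
v1 = Q_g.t() @ grad @ Q_a
v2 = v1 / (d_g.unsqueeze(1) * d_a.unsqueeze(0) + la)
v = Q_g @ v2 @ Q_a.t()

# Direct solve of (A_f kron G_f + la*I) vec(V) = vec(grad), column-major vec.
F = torch.kron(A_f, G_f) + la * torch.eye(m * n)
v_direct = torch.linalg.solve(F, grad.t().reshape(-1)).reshape(n, m).t()

print(torch.allclose(v, v_direct, atol=1e-5))  # expected: True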