def update_lr(self):
    # Loop over all modules
    for m in self.modules():
        # Only consider modules that participate in freezing (i.e. have an 'active' flag set)
        if hasattr(m, 'active') and m.active:
            # If we've passed this layer's freezing point, deactivate it.
            if self.j > m.max_j:
                m.active = False
                # Also remove this layer's parameter group from the optimizer,
                # so no further updates or optimizer state are computed for it.
                # (Rebuilding the list avoids removing items while iterating over it.)
                self.optim.param_groups = [group for group in self.optim.param_groups
                                           if group['layer_index'] != m.layer_index]
            # Otherwise, anneal this layer's LR along its own cosine schedule.
            else:
                for group in self.optim.param_groups:
                    if group['layer_index'] == m.layer_index:
                        # 0.05 * (1 + cos(pi * j / max_j)) starts at 0.1 and decays to 0;
                        # with scale_lr, the starting LR is additionally divided by lr_ratio.
                        if self.scale_lr:
                            group['lr'] = (0.05 / m.lr_ratio) * (1 + np.cos(np.pi * self.j / m.max_j))
                        else:
                            group['lr'] = 0.05 * (1 + np.cos(np.pi * self.j / m.max_j))
    # Advance the global iteration counter
    self.j += 1
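
For context, here is a minimal sketch of how a training loop might drive this method, since self.j counts iterations and update_lr is meant to be called once per step. The names net, train_loader, and criterion are assumptions for illustration and do not come from the snippet above; only the update_lr() call and the optimizer stored on the model are implied by the code.

    for inputs, targets in train_loader:
        net.optim.zero_grad()
        loss = criterion(net(inputs), targets)
        loss.backward()
        net.optim.step()
        # Advance each still-active layer along its cosine schedule and
        # freeze any layer whose iteration budget (max_j) has been reached.
        net.update_lr()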