from chainer import optimizers


def decrease_learning_rate(opt, factor, final_value):
    """Decay the optimizer's learning rate by ``factor``,
    keeping it from dropping below ``final_value``."""
    # Adam stores its learning rate in ``alpha``; the SGD-family
    # optimizers store it in ``lr``.
    if isinstance(opt, optimizers.Adam):
        attr = 'alpha'
    elif isinstance(opt, (optimizers.SGD,
                          optimizers.MomentumSGD,
                          optimizers.NesterovAG)):
        attr = 'lr'
    else:
        raise NotImplementedError(
            'Unsupported optimizer: {}'.format(type(opt).__name__))

    current = getattr(opt, attr)
    if current <= final_value:
        return final_value
    # Decay, but never go below the floor.
    setattr(opt, attr, max(current * factor, final_value))
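
For context, a minimal usage sketch of the helper, assuming a Chainer Adam optimizer; the decay factor, floor, and epoch count here are illustrative, not taken from the original:

from chainer import optimizers

# Hypothetical usage: halve Adam's alpha after every epoch, with a
# floor of 1e-5 (both values are illustrative).
opt = optimizers.Adam(alpha=1e-3)
for epoch in range(10):
    # ... run one epoch of training with opt ...
    decrease_learning_rate(opt, factor=0.5, final_value=1e-5)
    print(epoch, opt.alpha)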