def getlearningrate(epoch, opt):
    """Return the learning rate for ``epoch`` according to ``opt.lrPolicy``.

    Args:
        epoch: Epoch index. Every schedule below works with ``epoch + 1``,
            i.e. it treats ``epoch`` as zero-based.
        opt: Options object. Always reads ``opt.LR`` (base learning rate) and
            ``opt.lrPolicy``. Depending on the policy it also reads:

            * ``"multistep"``: ``opt.nEpochs`` and ``opt.ratio``. These are two
              fractions of ``nEpochs``. Past ``ratio[0]`` the rate is
              ``LR * 0.1``; past ``ratio[1]`` it is ``LR * 0.01``.
            * ``"linear"``: ``opt.nEpochs`` and ``opt.step``.
            * ``"exp"``: ``opt.step`` and ``opt.gamma``.
            * ``"fixed"``: nothing else; ``opt.LR`` is returned unchanged.

    Returns:
        The learning rate to use for this epoch.

    Raises:
        AssertionError: If ``opt.lrPolicy`` is not one of the policies above.
            It is raised explicitly rather than via ``assert`` so the check
            is not stripped when Python runs with ``-O``.
    """
    # Function-scope import: no module-level `import math` is visible in this file.
    import math

    base_lr = opt.LR
    policy = opt.lrPolicy

    if policy == "multistep":
        # Two-stage step decay. The later milestone (ratio[1]) is checked
        # first so it takes precedence. The original comments suggest ratio
        # pairs like 0.4/0.6 or 0.6/0.8.
        progress = epoch + 1.0
        if progress > opt.nEpochs * opt.ratio[1]:
            return base_lr * 0.01
        if progress > opt.nEpochs * opt.ratio[0]:
            return base_lr * 0.1
        return base_lr

    if policy == "linear":
        # Piecewise-constant linear ramp from opt.LR toward 0.001. The rate
        # changes once every `opt.step` epochs.
        # NOTE(review): the slope's denominator is ceil(nEpochs / 2). The
        # rate therefore lands exactly on 0.001 at the last epoch
        # (epoch == nEpochs - 1) only when opt.step == 2. With step == 1 it
        # passes 0.001 halfway through and can go negative; with a larger
        # step it never reaches 0.001. Confirm this is intended.
        slope = (0.001 - base_lr) / math.ceil(opt.nEpochs / 2.0)
        return slope * math.ceil((epoch + 1) / opt.step) + base_lr

    if policy == "exp":
        # Multiply by gamma once every `opt.step` epochs.
        power = math.floor((epoch + 1) / opt.step)
        return base_lr * math.pow(opt.gamma, power)

    if policy == "fixed":
        return base_lr

    raise AssertionError(f"invalid lr policy: {policy!r}")
# Web-page residue from the original source (not code):
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")