import os

import torch


def build_optimizer(opt, model, infos):
    opt.pre_ft = getattr(opt, 'pre_ft', 1)
    # Optionally restrict to trainable parameters only:
    # model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimize = opt.optim
    if optimize == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
    elif optimize == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=opt.learning_rate, momentum=0.999, weight_decay=0.0005)
    elif optimize == 'Adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
    elif optimize == 'Adagrad':
        optimizer = torch.optim.Adagrad(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
    elif optimize == 'Adamax':
        optimizer = torch.optim.Adamax(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
    elif optimize == 'ASGD':
        optimizer = torch.optim.ASGD(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
    elif optimize == 'LBFGS':
        # LBFGS does not accept a weight_decay argument
        optimizer = torch.optim.LBFGS(model.parameters(), lr=opt.learning_rate)
    elif optimize == 'RMSprop':
        optimizer = torch.optim.RMSprop(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
    else:
        raise ValueError('Unknown optimizer: {}'.format(optimize))
    infos['optimized'] = True
    # Resume the optimizer state if a checkpoint directory is given
    if len(opt.start_from) != 0:
        optimizer_path = os.path.join(opt.start_from, opt.model_id + '.optimizer.pth')
        if os.path.isfile(optimizer_path):
            optimizer.load_state_dict(torch.load(optimizer_path))
    return optimizer, infos
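A minimal usage sketch follows, assuming `opt` is an argparse-style namespace carrying the fields the function reads (`optim`, `learning_rate`, `start_from`, `model_id`) and that `model` is any `torch.nn.Module`. The concrete values and the `nn.Linear` stand-in model are illustrative assumptions, not part of the original code.

# Usage sketch (assumed names and values; not from the original code)
from argparse import Namespace

import torch
import torch.nn as nn

opt = Namespace(optim='adam', learning_rate=4e-4, start_from='', model_id='demo')
model = nn.Linear(10, 2)   # stand-in for the real model
infos = {}

optimizer, infos = build_optimizer(opt, model, infos)

# One training step with the returned optimizer
x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()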