def get_optimizer(model, exp_name):
    """Create an optimizer for ``model`` from the experiment's config.

    The config loaded for ``exp_name`` must provide ``'optimizer'`` (the name
    of a class in ``torch.optim``, e.g. ``'SGD'`` or ``'Adam'``) and
    ``'learning_rate'``. The ``'momentum'`` and ``'weight_decay'`` values are
    forwarded only when the chosen optimizer's constructor accepts them. This
    lets the same config layout work for optimizers that lack those
    hyperparameters (``Adam`` has no ``momentum``; ``LBFGS`` has neither). For
    such optimizers, the corresponding config values are ignored.

    Args:
        model: object whose ``parameters()`` are handed to the optimizer.
        exp_name: experiment name forwarded to ``config.load_config_file``.

    Returns:
        An instance of ``torch.optim.<cfg['optimizer']>``.

    Raises:
        KeyError: a required config key is missing. This covers
            ``'optimizer'``, ``'learning_rate'``, or an optional
            hyperparameter that the selected optimizer accepts.
        AttributeError: ``cfg['optimizer']`` is not a name in ``torch.optim``.
    """
    import inspect

    cfg = config.load_config_file(exp_name)
    optimizer_name = cfg['optimizer']
    optimizer_cls = getattr(torch.optim, optimizer_name)

    # Constructor parameters of the selected optimizer class (``self`` excluded).
    accepted = inspect.signature(optimizer_cls).parameters
    takes_any_kwarg = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in accepted.values()
    )

    kwargs = {'lr': cfg['learning_rate']}
    # Only forward optional hyperparameters the constructor understands:
    # passing ``momentum`` to e.g. Adam raises TypeError.
    for key in ('momentum', 'weight_decay'):
        if takes_any_kwarg or key in accepted:
            kwargs[key] = cfg[key]

    return optimizer_cls(model.parameters(), **kwargs)
评论列表
文章目录