models.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:MIL.pytorch 作者: gujiuxiang 项目源码 文件源码
def build_optimizer(opt, model, infos):
    """Build a torch optimizer over ``model.parameters()`` from ``opt`` settings.

    Args:
        opt: options namespace with attributes ``optim`` (optimizer name),
            ``learning_rate``, ``start_from`` (checkpoint directory, empty
            string to skip resuming) and ``model_id`` (checkpoint file prefix).
            ``opt.pre_ft`` is defaulted to 1 if absent.
        model: ``torch.nn.Module`` whose parameters are optimized.
        infos: bookkeeping dict; this function sets ``infos['optimized'] = True``.

    Returns:
        Tuple ``(optimizer, infos)``.

    Raises:
        ValueError: if ``opt.optim`` names an unsupported optimizer.
    """
    # Default the pre-finetune flag when absent (kept for callers that read it).
    opt.pre_ft = getattr(opt, 'pre_ft', 1)

    # Dispatch table replaces the original if/elif chain.
    optimizer_classes = {
        'adam': torch.optim.Adam,
        'sgd': torch.optim.SGD,
        'Adadelta': torch.optim.Adadelta,
        'Adagrad': torch.optim.Adagrad,
        'Adamax': torch.optim.Adamax,
        'ASGD': torch.optim.ASGD,
        'LBFGS': torch.optim.LBFGS,
        'RMSprop': torch.optim.RMSprop,
    }
    name = opt.optim
    if name not in optimizer_classes:
        # The original code fell through every elif and later crashed with
        # NameError on `optimizer`; fail fast with a clear message instead.
        raise ValueError('Unknown optimizer: %s' % name)

    kwargs = {'lr': opt.learning_rate}
    if name != 'LBFGS':
        # torch.optim.LBFGS does not accept weight_decay; passing it (as the
        # original code did) raised a TypeError.
        kwargs['weight_decay'] = 0.0005
    if name == 'sgd':
        # Preserve the original hard-coded SGD momentum.
        kwargs['momentum'] = 0.999
    optimizer = optimizer_classes[name](model.parameters(), **kwargs)

    infos['optimized'] = True

    # Resume optimizer state from a checkpoint, if one exists.
    if len(opt.start_from) != 0:
        ckpt_path = os.path.join(opt.start_from, opt.model_id + '.optimizer.pth')
        if os.path.isfile(ckpt_path):
            optimizer.load_state_dict(torch.load(ckpt_path))

    return optimizer, infos
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号