models.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:MIL.pytorch 作者: gujiuxiang 项目源码 文件源码
def build_mil(opt):
    """Build the MIL model selected by ``opt`` and prepare it for training.

    Uses the ResNet-101 variant when ``'resnet101'`` appears in ``opt.model``
    and the VGG variant otherwise, optionally wraps it in ``nn.DataParallel``,
    optionally restores weights from a checkpoint directory, then moves it to
    CUDA and switches it to training mode.

    Args:
        opt: Options namespace. Attributes read here:
            - ``model`` (str): backbone selector.
            - ``n_gpus`` (int, optional): number of GPUs; defaults to 1 and
              the resolved value is written back onto ``opt``.
            - ``gpus``: device ids passed to ``nn.DataParallel``; only read
              when ``n_gpus > 1``.
            - ``start_from`` (str, optional): checkpoint directory to resume
              from; an empty string, ``None`` or a missing attribute means
              "train from scratch".

    Returns:
        The model (``DataParallel``-wrapped when ``n_gpus > 1``) on CUDA and
        in training mode.

    Raises:
        AssertionError: if ``start_from`` is set but is not a directory, or
            the expected ``.infos-best.pkl`` / ``.model-best.pth`` file is
            missing from it.
    """
    opt.n_gpus = getattr(opt, 'n_gpus', 1)

    mil_model = resnet_mil(opt) if 'resnet101' in opt.model else vgg_mil(opt)

    if opt.n_gpus > 1:
        print('Construct multi-gpu model ...')
        model = nn.DataParallel(mil_model, device_ids=opt.gpus, dim=0)
    else:
        model = mil_model

    # Check compatibility if training is continued from a previously saved model.
    start_from = getattr(opt, 'start_from', None)
    if start_from:
        assert os.path.isdir(start_from), "%s must be a directory" % start_from
        # Checkpoint files are named after the directory itself, e.g.
        # <dir>/<basename(dir)>.model-best.pth. normpath strips a trailing
        # separator, otherwise basename() would return '' for "<dir>/".
        prefix = os.path.join(start_from, os.path.basename(os.path.normpath(start_from)))
        lm_info_path = prefix + '.infos-best.pkl'
        lm_pth_path = prefix + '.model-best.pth'
        # The infos file is not loaded here; only its presence is verified.
        assert os.path.isfile(lm_info_path), "infos.pkl file does not exist in path %s" % start_from
        assert os.path.isfile(lm_pth_path), "model-best.pth file does not exist in path %s" % start_from
        # Deserialize onto the CPU so loading does not depend on the GPU the
        # checkpoint was saved from; load_state_dict copies the values into
        # the model's parameters, which are moved to CUDA below.
        state_dict = torch.load(lm_pth_path, map_location=lambda storage, loc: storage)
        # NOTE(review): when n_gpus > 1 the keys are matched against the
        # DataParallel wrapper (which prefixes them with "module."), so the
        # checkpoint must have been saved with the same GPU setup — confirm.
        model.load_state_dict(state_dict)
    model.cuda()
    model.train()  # Assure in training mode
    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号