def build_mil(opt):
    """Build the MIL network, optionally wrap it for multi-GPU, and restore weights.

    Args:
        opt: Options namespace. Attributes read:
            - ``model``: if it contains ``'resnet101'`` the network is built by
              ``resnet_mil(opt)``, otherwise by ``vgg_mil(opt)``.
            - ``n_gpus`` (optional): defaults to 1 and is written back onto
              ``opt``. When greater than 1 the network is wrapped in
              ``nn.DataParallel`` with ``device_ids=opt.gpus``.
            - ``gpus``: device ids for ``nn.DataParallel`` (only read when
              ``n_gpus > 1``).
            - ``start_from`` (optional): directory of a previous run. When
              non-empty, ``<dir>/<dirname>.model-best.pth`` is loaded into the
              model, and ``<dir>/<dirname>.infos-best.pkl`` must also exist.

    Returns:
        The model (an ``nn.DataParallel`` wrapper when ``n_gpus > 1``), moved to
        CUDA and switched to training mode.

    Raises:
        NotADirectoryError: ``opt.start_from`` is set but is not a directory.
        FileNotFoundError: the infos file or the model checkpoint is missing.
    """
    # opt may not define n_gpus; default to 1 and store it back on opt so
    # later code reading opt.n_gpus sees the same value.
    opt.n_gpus = getattr(opt, 'n_gpus', 1)

    if 'resnet101' in opt.model:
        mil_model = resnet_mil(opt)
    else:
        mil_model = vgg_mil(opt)

    if opt.n_gpus > 1:
        print('Construct multi-gpu model ...')
        model = nn.DataParallel(mil_model, device_ids=opt.gpus, dim=0)
    else:
        model = mil_model

    # Resume from a previous run's "best" checkpoint when a directory is given.
    # A missing or None start_from is treated like an empty string (no resume).
    start_from = getattr(opt, 'start_from', None)
    if start_from:
        # Explicit raises instead of assert: asserts are stripped under -O.
        if not os.path.isdir(start_from):
            raise NotADirectoryError("%s must be a directory path" % start_from)
        # normpath drops a trailing separator; without it basename('run/') is ''
        # and the file names below would collapse to '.infos-best.pkl' etc.
        run_name = os.path.basename(os.path.normpath(start_from))
        lm_info_path = os.path.join(start_from, run_name + '.infos-best.pkl')
        lm_pth_path = os.path.join(start_from, run_name + '.model-best.pth')
        if not os.path.isfile(lm_info_path):
            raise FileNotFoundError(
                "infos.pkl file does not exist in path %s" % start_from)
        if not os.path.isfile(lm_pth_path):
            raise FileNotFoundError(
                "model .pth file does not exist in path %s" % start_from)
        # Keep every loaded storage on CPU: the checkpoint may reference GPU ids
        # that do not exist here, and the model is moved to CUDA just below.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(lm_pth_path,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict)

    model.cuda()
    model.train()  # ensure training mode (affects dropout/batch-norm, if any)
    return model
评论列表
文章目录