def _load_module_from_path(name, path):
    """Import the Python source file at ``path`` as a module named ``name``.

    Drop-in replacement for ``imp.load_source``, which is deprecated since
    Python 3.4 and removed in 3.12. On interpreters whose ``importlib`` lacks
    ``module_from_spec`` (Python < 3.5, including Python 2) it falls back to
    ``imp.load_source``.
    """
    import sys

    try:
        from importlib.machinery import SourceFileLoader
        from importlib.util import module_from_spec, spec_from_file_location
    except ImportError:
        import imp
        return imp.load_source(name, path)

    # An explicit SourceFileLoader accepts any file name, as imp.load_source
    # did; spec_from_file_location alone returns None for unknown suffixes.
    loader = SourceFileLoader(name, path)
    spec = spec_from_file_location(name, path, loader=loader)
    module = module_from_spec(spec)
    # imp.load_source registered the module in sys.modules; keep doing so, so
    # that lookups of the class's module by name (e.g. pickling) still work.
    sys.modules[name] = module
    loader.exec_module(module)
    return module


def get_model_and_optimizer(result_dir, modelfn, opt, opt_kwargs, net_kwargs, gpu):
    """Load a model class from a source file and build it with its optimizer.

    The model file must define a class whose name equals the file's base name
    up to the first dot (e.g. ``models/VGG.py`` -> class ``VGG``). After the
    class is loaded, the model file and the file containing this function are
    copied into ``result_dir`` (a copy is skipped if a file with the same name
    already exists there), so the run's code is kept next to its results.

    Args:
        result_dir: Directory that receives the copies. It must already exist;
            ``shutil.copy`` does not create it.
        modelfn: Path to the Python source file defining the model class.
        opt: Name looked up in ``optimizers.__dict__`` to obtain the
            optimizer constructor (e.g. ``'Adam'``).
        opt_kwargs: Keyword arguments passed to the optimizer constructor.
        net_kwargs: Keyword arguments passed to the model constructor.
        gpu: When ``>= 0``, ``model.to_gpu()`` is called on the new model.

    Returns:
        ``(model, optimizer)``, where ``optimizer.setup(model)`` has already
        been called.

    Raises:
        AttributeError: If the model file does not define the expected class.
        KeyError: If ``opt`` is not a name in ``optimizers``.
    """
    model_fn = os.path.basename(modelfn)
    model_name = model_fn.split('.')[0]
    module = _load_module_from_path(model_name, modelfn)
    Net = getattr(module, model_name)

    # Snapshot the model definition and this script alongside the results;
    # existing copies are left untouched.
    for src in (modelfn, __file__):
        dst = os.path.join(result_dir, os.path.basename(src))
        if not os.path.exists(dst):
            shutil.copy(src, dst)

    # prepare model
    model = Net(**net_kwargs)
    if gpu >= 0:
        # NOTE(review): the gpu id itself is not passed to to_gpu(); presumably
        # the caller selects the current device beforehand — confirm.
        model.to_gpu()
    optimizer = optimizers.__dict__[opt](**opt_kwargs)
    optimizer.setup(model)
    return model, optimizer
评论列表
文章目录