def get_mod(ags):
    """Build the model selected by ``ags.mtype`` and optionally resume it.

    Only the selected builder is invoked, and it is invoked exactly once with
    the configured optimizer.

    Args:
        ags: Options object. The attributes read here are ``wpath``,
            ``versn``, ``optim``, ``lrate``, ``mtype`` and ``begin``.

    Returns:
        A ``(model, b_scr)`` tuple. ``b_scr`` is ``-1`` unless weights were
        restored, in which case it is the float parsed from the leading
        ``_``-separated field of the restored file's base name.

    Raises:
        IndexError: if ``ags.mtype`` does not index the builder table
            (negative values index from the end, like any sequence lookup).
        ValueError: if the restored file's name does not begin with a
            float-parsable field.
    """
    dst = os.path.join(ags.wpath, ags.versn)

    if ags.optim == 'adam':
        opt = adam(ags.lrate)
    elif ags.optim == 'sgd':
        opt = sgd(ags.lrate)
    else:
        # Any other optimizer name falls back to adam() without a learning
        # rate argument.
        opt = adam()

    # (builder, label used in the log line), indexed by ags.mtype.
    builders = (
        (build_mod2, '2'),
        (build_mod3, '3'),
        (build_mod7, '7'),
        (build_mod9, '9'),
        (build_mod11, '11'),
        (build_mod12, '12'),
        (build_mod13, '13'),
    )
    build, label = builders[ags.mtype]
    model = build(opt)
    logging.info('start with model %s', label)

    b_scr = -1
    if ags.begin == -1:
        # NOTE(review): sorted() orders file names lexicographically, so the
        # last entry has the largest numeric prefix only if all prefixes have
        # the same width -- confirm against the checkpoint naming scheme.
        fls = sorted(glob.glob(dst + '/*h5'))
        if fls:
            latest = fls[-1]
            logging.info('load weights: %s', latest)
            model.load_weights(latest)
            b_scr = float(os.path.basename(latest).split('_')[0])
    return model, b_scr
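# Comment list      (stray web-page navigation text, not part of the code)
# Article contents  (stray web-page navigation text, not part of the code)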