def train_multilabel_bts(lang_db, imdb, pretrained, max_iters=1000, loss_func='squared_hinge', box_method='random'):
    """Train ``pretrained`` on generated multi-label batches and save it.

    Checkpoints, TensorBoard logs and the final model are written under
    ``output/bts_ckpt/<imdb.name>/``.

    Args:
        lang_db: Object exposing a ``name`` attribute; forwarded to
            ``load_multilabel_data`` and used in output file names.
        imdb: Object exposing a ``name`` attribute; forwarded to
            ``load_multilabel_data`` and used for the output directory
            and file names.
        pretrained: Model object exposing ``fit_generator`` and ``save``;
            also forwarded to ``load_multilabel_data``.
        max_iters: Passed to ``fit_generator`` as ``epochs``.
        loss_func: Only used as a tag in the saved model's file name; it
            does not configure the model here.
        box_method: Forwarded to ``load_multilabel_data`` and used as a
            tag in the saved model's file name.
    """
    # Output layout: checkpoints in dir_path, TensorBoard events in the
    # dedicated sub-directory tensor_path.
    dir_path = osp.join('output', 'bts_ckpt', imdb.name)
    tensor_path = osp.join(dir_path, 'log_dir')
    for path in (dir_path, tensor_path):
        if not osp.exists(path):
            os.makedirs(path)

    # Create callback_list.
    ckpt_save = osp.join(dir_path, lang_db.name + '_multi_label_fixed_weights-{epoch:02d}.hdf5')
    checkpoint = ModelCheckpoint(ckpt_save, monitor='loss', verbose=1, save_best_only=True)
    early_stop = EarlyStopping(monitor='loss', min_delta=0, patience=3, verbose=0, mode='auto')
    # Log into tensor_path: the directory was created for this purpose but
    # previously went unused because log_dir pointed at dir_path.
    # NOTE(review): a non-zero histogram_freq without validation_data may be
    # rejected by some Keras versions -- confirm against the installed one.
    tensorboard = TensorBoard(log_dir=tensor_path, histogram_freq=2000, write_graph=True, write_images=False)
    callback_list = [checkpoint, early_stop, tensorboard]

    pretrained.fit_generator(
        load_multilabel_data(imdb, lang_db, pretrained, box_method),
        steps_per_epoch=5000,
        epochs=max_iters,
        verbose=1,
        callbacks=callback_list,
        workers=1,
    )

    final_name = 'model_fixed' + imdb.name + '_' + lang_db.name + '_ML_' + box_method + '_' + loss_func + '.hdf5'
    pretrained.save(osp.join(dir_path, final_name))
# (web-page residue) Comment list
# (web-page residue) Table of contents