import os
import os.path as osp

from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard


def train_bts(lang_db, imdb, max_iters=1000, loss='squared_hinge'):
    # Define the network that regresses image features onto word-vector targets.
    model = define_network(lang_db.vector_size, loss)
    # model = load_model(osp.join('output', 'bts_ckpt', 'imagenet1k_train_bts', 'glove_wiki_300_hinge_weights-03.hdf5'))

    # Create output directories for checkpoints and TensorBoard logs.
    dir_path = osp.join('output', 'bts_ckpt', imdb.name)
    if not osp.exists(dir_path):
        os.makedirs(dir_path)
    log_dir = osp.join('output', 'bts_logs', imdb.name)
    if not osp.exists(log_dir):
        os.makedirs(log_dir)

    # Callbacks: checkpoint the best model by validation loss, stop early after
    # three epochs without improvement, and write TensorBoard logs.
    ckpt_save = osp.join(dir_path, lang_db.name + "_" + loss + "_weights-{epoch:02d}.hdf5")
    checkpoint = ModelCheckpoint(ckpt_save, monitor='val_loss', verbose=1, save_best_only=True)
    early_stop = EarlyStopping(monitor='val_loss', min_delta=0, patience=3, verbose=0, mode='auto')
    tensorboard = TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True, write_images=False)
    callback_list = [checkpoint, early_stop, tensorboard]

    # Train with the project's data generator; early stopping normally ends
    # training well before max_iters epochs.
    model.fit_generator(load_data(imdb, lang_db),
                        steps_per_epoch=5000,
                        epochs=max_iters,
                        verbose=1,
                        validation_data=imdb.load_val_data(lang_db),
                        validation_steps=20000,  # number of validation steps (images) to evaluate on
                        callbacks=callback_list,
                        workers=1)

    model.save(osp.join(dir_path, 'model_' + imdb.name + '_' + lang_db.name + '_' + loss + '_l2.hdf5'))
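

# `define_network` and `load_data` are defined elsewhere in this project and are
# not shown above. Purely as an illustration (an assumption, not the project's
# actual code), a network compatible with this training loop could be a small
# fully connected regressor from an image-feature vector onto the
# lang_db.vector_size-dimensional word embedding, compiled with the loss name
# that train_bts receives (e.g. 'squared_hinge'). The feature size below is a
# hypothetical placeholder.
def define_network_sketch(embedding_size, loss, feature_size=2048):
    from keras.models import Sequential
    from keras.layers import Dense

    model = Sequential()
    model.add(Dense(1024, activation='relu', input_shape=(feature_size,)))
    model.add(Dense(embedding_size, activation='linear'))
    # The loss string ('squared_hinge', 'hinge', 'mse', ...) is passed straight
    # to Keras, which resolves it to the corresponding built-in loss function.
    model.compile(optimizer='adam', loss=loss)
    return model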