train_bts.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Msc_Multi_label_ZeroShot 作者: thomasSve 项目源码 文件源码
def train_multilabel_bts(lang_db, imdb, pretrained, max_iters=1000, loss_func='squared_hinge', box_method='random'):
    """Train ``pretrained`` on multi-label data from a generator and save it.

    Checkpoints, TensorBoard logs and the final model are written under
    ``output/bts_ckpt/<imdb.name>``.

    Args:
        lang_db: Language database; only its ``name`` attribute is used here
            (in filenames). It is also forwarded to ``load_multilabel_data``.
        imdb: Image database; its ``name`` selects the output directory and
            is forwarded to ``load_multilabel_data``.
        pretrained: Model exposing ``fit_generator`` and ``save``; presumably
            an already-compiled Keras model (TODO confirm against caller).
        max_iters: Number of *epochs* passed to ``fit_generator``. Each epoch
            runs 5000 generator steps, despite the parameter's name.
        loss_func: Only used to build the final model's filename. This
            function does not compile the model, so the loss actually used
            is whatever ``pretrained`` was compiled with.
        box_method: Forwarded to ``load_multilabel_data`` and embedded in the
            final model's filename.
    """
    dir_path = osp.join('output', 'bts_ckpt', imdb.name)
    tensor_path = osp.join(dir_path, 'log_dir')
    # os.makedirs creates missing parents, so this also creates dir_path.
    if not osp.exists(tensor_path):
        os.makedirs(tensor_path)

    # Keep the best (lowest training-loss) weights; epoch number in the name.
    ckpt_save = osp.join(dir_path, lang_db.name + '_multi_label_fixed_' + 'weights-{epoch:02d}.hdf5')
    checkpoint = ModelCheckpoint(ckpt_save, monitor='loss', verbose=1, save_best_only=True)
    early_stop = EarlyStopping(monitor='loss', min_delta=0, patience=3, verbose=0, mode='auto')
    # Log into the dedicated sub-directory created above (previously the
    # logs went to dir_path and log_dir stayed empty). Histograms are
    # disabled: they need non-generator validation_data, and none is
    # passed to fit_generator below.
    tensorboard = TensorBoard(log_dir=tensor_path, histogram_freq=0, write_graph=True, write_images=False)
    callback_list = [checkpoint, early_stop, tensorboard]

    pretrained.fit_generator(load_multilabel_data(imdb, lang_db, pretrained, box_method),
                             steps_per_epoch=5000,
                             epochs=max_iters,
                             verbose=1,
                             callbacks=callback_list,
                             workers=1)

    model_file = 'model_fixed' + imdb.name + '_' + lang_db.name + '_ML_' + box_method + '_' + loss_func + '.hdf5'
    pretrained.save(osp.join(dir_path, model_file))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号