train_bts.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目: Msc_Multi_label_ZeroShot | 作者: thomasSve | 项目源码 | 文件源码
def train_bts(lang_db, imdb, max_iters=1000, loss='squared_hinge',
              steps_per_epoch=5000, validation_steps=20000, patience=3):
    """Train the BTS network on ``imdb`` and save checkpoints, logs and the final model.

    Builds a fresh network with ``define_network(lang_db.vector_size, loss)``,
    trains it with ``model.fit_generator`` on ``load_data(imdb, lang_db)`` and
    validates on ``imdb.load_val_data(lang_db)`` after every epoch.

    Files written (directories are created if they do not exist):
      * ``output/bts_ckpt/<imdb.name>/<lang_db.name>_<loss>_weights-<NN>.hdf5``
        -- written by ModelCheckpoint, only for epochs where ``val_loss`` improved.
      * ``output/bts_logs/<imdb.name>/`` -- TensorBoard logs.
      * ``output/bts_ckpt/<imdb.name>/model_<imdb.name>_<lang_db.name>_<loss>_l2.hdf5``
        -- the model as it stands after the last epoch that ran.

    Args:
        lang_db: Language/embedding database. ``lang_db.vector_size`` is passed to
            ``define_network`` (presumably the embedding dimension -- confirm), and
            ``lang_db.name`` is used in output file names.
        imdb: Image database. Must provide ``name`` and ``load_val_data(lang_db)``.
        max_iters: Maximum number of *epochs* (not iterations). EarlyStopping may
            end training earlier.
        loss: Loss name passed to ``define_network``. It is also embedded in the
            output file names.
        steps_per_epoch: Number of batches drawn from the training generator per epoch.
        validation_steps: Number of batches drawn from the validation data per epoch.
            Keras only uses this when the validation data is a generator. It equals
            the number of images only if each batch holds a single image.
        patience: Number of epochs without ``val_loss`` improvement before
            EarlyStopping halts training.
    """
    model = define_network(lang_db.vector_size, loss)

    ckpt_dir = osp.join('output', 'bts_ckpt', imdb.name)
    log_dir = osp.join('output', 'bts_logs', imdb.name)
    for path in (ckpt_dir, log_dir):
        if not osp.exists(path):
            os.makedirs(path)

    # Keras substitutes the epoch number into {epoch:02d}. Because
    # save_best_only=True, a new file appears only when val_loss improves.
    ckpt_save = osp.join(ckpt_dir, lang_db.name + "_" + loss + "_weights-{epoch:02d}.hdf5")
    checkpoint = ModelCheckpoint(ckpt_save, monitor='val_loss', verbose=1, save_best_only=True)
    early_stop = EarlyStopping(monitor='val_loss', min_delta=0, patience=patience,
                               verbose=0, mode='auto')
    tensorboard = TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True,
                              write_images=False)
    callback_list = [checkpoint, early_stop, tensorboard]

    model.fit_generator(load_data(imdb, lang_db),
                        steps_per_epoch=steps_per_epoch,
                        epochs=max_iters,
                        verbose=1,
                        validation_data=imdb.load_val_data(lang_db),
                        validation_steps=validation_steps,  # batches, not images
                        callbacks=callback_list,
                        workers=1)

    # EarlyStopping is not asked to restore the best weights here, so this saves
    # the final-epoch state. The best epoch is in the checkpoint files above.
    # NOTE(review): the meaning of the '_l2' suffix is not visible here; kept as-is.
    model.save(osp.join(ckpt_dir, 'model_' + imdb.name + '_' + lang_db.name + '_' + loss + '_l2.hdf5'))
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号