train.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:Image-Caption-Generator 作者: abi-aryan 项目源码 文件源码
def train(batch_size=128,
          epochs=100,
          data_dir="/home/shagun/projects/Image-Caption-Generator/data/",
          weights_path=None,
          mode="train"):
    '''Train the image caption generator.

    Builds the configuration with ``generate_config``, creates the model with
    ``create_model``, optionally restores previously saved weights, and fits
    the model on batches from ``debug_generator``. A checkpoint callback and a
    TensorBoard callback are attached.

    Args:
        batch_size: Number of examples per batch; must be positive. Also
            stored in the config dict under ``'batch_size'``.
        epochs: Number of training epochs passed to ``fit_generator``.
        data_dir: Data directory passed to ``generate_config`` and
            ``debug_generator``. Checkpoints are written under its ``model``
            subdirectory. A trailing path separator is optional.
            NOTE(review): the default is a developer-specific absolute path;
            callers on other machines must pass their own.
        weights_path: Path to the .h5 file holding weights from a previous
            run. If falsy, training starts from the freshly created model.
        mode: Forwarded unchanged to ``generate_config``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    '''
    import os

    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))

    config_dict = generate_config(data_dir=data_dir,
                                  mode=mode)
    config_dict['batch_size'] = batch_size
    # Floor division drops a trailing partial batch. Clamp to 1 so a dataset
    # smaller than one batch still trains instead of silently running 0 steps.
    steps_per_epoch = max(1, config_dict["total_number_of_examples"] // batch_size)

    print("steps_per_epoch = ", steps_per_epoch)
    # NOTE(review): training draws from `debug_generator`; confirm a debug
    # generator (rather than a full training generator) is intended here.
    train_data_generator = debug_generator(config_dict=config_dict,
                                           data_dir=data_dir)

    model = create_model(config_dict=config_dict)

    if weights_path:
        model.load_weights(weights_path)

    # os.path.join works whether or not data_dir ends with a separator.
    # The plain string concatenation it replaces required the trailing slash.
    file_name = os.path.join(data_dir, "model", "weights-{epoch:02d}.hdf5")
    # Save a new checkpoint only when the monitored training loss improves.
    checkpoint = ModelCheckpoint(filepath=file_name,
                                 monitor='loss',
                                 verbose=1,
                                 save_best_only=True,
                                 mode='min')
    # NOTE(review): log_dir is relative to the current working directory.
    # write_grads only has an effect when histogram_freq > 0, so with
    # histogram_freq=0 it is inert.
    tensorboard = TensorBoard(log_dir='../logs',
                              histogram_freq=0,
                              batch_size=batch_size,
                              write_graph=True,
                              write_grads=True,
                              write_images=False,
                              embeddings_freq=0,
                              embeddings_layer_names=None,
                              embeddings_metadata=None)

    callbacks_list = [checkpoint, tensorboard]
    model.fit_generator(
        generator=train_data_generator,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=2,
        callbacks=callbacks_list)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号