test.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:Image-Caption-Generator 作者: shagunsodhani 项目源码 文件源码
def predict(image_name,
            data_dir="/home/shagun/projects/Image-Caption-Generator/data/",
            weights_path=None,
            mode="test"):
    """Generate captions for a single image with a trained caption model.

    The image is embedded with the VGG16 model returned by ``load_vgg16``,
    the caption model is built from the generated config, its weights are
    loaded, and ``gen_captions`` is asked for two captions.

    Args:
        image_name: File name of the image; it is read from
            ``data_dir + "images/" + image_name``.
        data_dir: Base data directory. It is joined to sub-paths by plain
            string concatenation, so it must end with a path separator.
        weights_path: File name of the .h5 weights file (model); it is
            loaded from ``data_dir + "model/" + weights_path``. Required,
            despite the ``None`` default.
        mode: Passed through to ``generate_config``.

    Returns:
        None. Whatever ``gen_captions`` returns is discarded.

    Raises:
        TypeError: If ``weights_path`` is None.
    """
    # Fail fast: without this check the same TypeError would only surface at
    # the string concatenation below, after VGG16 has been loaded and the
    # caption model built.
    if weights_path is None:
        raise TypeError(
            "weights_path must be the name of a weights file under "
            "data_dir + 'model/', not None"
        )

    image_path = data_dir + "images/" + image_name
    vgg_model = load_vgg16()
    image_embedding = vgg_model.predict(load_image(image_path))

    config_dict = generate_config(data_dir=data_dir,
                                  mode=mode)
    print(config_dict)

    # The model is only used for inference here, so it is built uncompiled.
    model = create_model(config_dict=config_dict,
                         compile_model=False)
    model.load_weights(data_dir + "model/" + weights_path)

    tokenizer = get_tokenizer(config_dict=config_dict,
                              data_dir=data_dir)

    # Invert the tokenizer vocabulary (word -> index) into index -> word.
    index_to_word = {index: word
                     for word, index in tokenizer.word_index.items()}

    gen_captions(config=config_dict,
                 model=model,
                 image_embedding=image_embedding,
                 tokenizer=tokenizer,
                 num_captions=2,
                 index_to_word=index_to_word)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号