def predict(image_name,
            data_dir="/home/shagun/projects/Image-Caption-Generator/data/",
            weights_path=None,
            mode="test",
            num_captions=2):
    """Generate captions for a single image with a trained caption model.

    The image is read from ``<data_dir>/images/<image_name>`` and embedded
    with the model returned by ``load_vgg16()``. A caption model is built by
    ``create_model`` (uncompiled), its weights are loaded from
    ``<data_dir>/model/<weights_path>``, and ``gen_captions`` is called for
    the embedding.

    Args:
        image_name: File name of the image inside ``<data_dir>/images/``.
        data_dir: Root data directory; a trailing path separator is optional.
        weights_path: File name of the saved model weights (e.g. an ``.h5``
            file), relative to ``<data_dir>/model/``. Required despite the
            ``None`` default (kept for signature compatibility).
        mode: Forwarded to ``generate_config``.
        num_captions: Forwarded to ``gen_captions``; defaults to 2, the value
            this function previously hard-coded.

    Returns:
        None. Whatever ``gen_captions`` returns is not collected here
        (NOTE(review): presumably it prints/emits the captions — confirm).

    Raises:
        TypeError: If ``weights_path`` is not provided. This is checked before
            any model is loaded, so the failure is immediate.
    """
    import os

    # Fail fast: without this check a missing weights_path only surfaced as a
    # `str + None` TypeError after VGG16 had already been loaded and run.
    if weights_path is None:
        raise TypeError("predict() requires weights_path: the weights file "
                        "name under <data_dir>/model/")

    # os.path.join tolerates data_dir with or without a trailing separator.
    image_path = os.path.join(data_dir, "images", image_name)
    vgg_model = load_vgg16()
    vgg_embedding = vgg_model.predict(load_image(image_path))
    # Kept as a list so the captioning loop below works for several images.
    image_embeddings = [vgg_embedding]

    config_dict = generate_config(data_dir=data_dir, mode=mode)
    # Debug output of the generated config (preserved existing behaviour).
    print(config_dict)
    model = create_model(config_dict=config_dict, compile_model=False)
    model.load_weights(os.path.join(data_dir, "model", weights_path))

    tokenizer = get_tokenizer(config_dict=config_dict, data_dir=data_dir)
    # Invert tokenizer.word_index (keys <-> values); handed to gen_captions
    # as the index -> word lookup.
    index_to_word = {v: k for k, v in tokenizer.word_index.items()}
    for image_embedding in image_embeddings:
        gen_captions(config=config_dict,
                     model=model,
                     image_embedding=image_embedding,
                     tokenizer=tokenizer,
                     num_captions=num_captions,
                     index_to_word=index_to_word)
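# 评论列表 ("comment list") — web-page residue, not code
# 文章目录 ("table of contents") — web-page residue, not code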