def compute_embeddings(images):
    """Compute a ``pool_3`` embedding for each image file.

    Builds the graph via ``create_graph()``, then opens a TensorFlow session
    with GPUs disabled. For each file, feeds its raw bytes to
    ``'DecodeJpeg/contents:0'`` and fetches ``'pool_3:0'``.

    Args:
        images: Iterable of image file names readable by ``tf.gfile``. It is
            wrapped in a progress bar, so it should presumably support
            ``len()`` — confirm against callers.

    Returns:
        Dict mapping each image file name to its embedding, reshaped to
        shape ``(2048,)``.
    """
    # Creates graph from saved GraphDef.
    create_graph()
    filename_to_emb = {}
    # Create no GPU devices, so inference runs on the CPU.
    config = tf.ConfigProto(device_count={'GPU': 0})
    bar = progressbar.ProgressBar(
        widgets=[progressbar.Bar('=', '[', ']'), ' ', progressbar.Percentage()])
    with tf.Session(config=config) as sess:
        # Some useful tensors:
        # 'softmax:0': A tensor containing the normalized prediction across
        #   1000 labels.
        # 'pool_3:0': A tensor containing the next-to-last layer containing
        #   2048 float description of the image.
        # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
        #   encoding of the image.
        # The embedding tensor is the same for every image, so look it up once.
        embedding_tensor = sess.graph.get_tensor_by_name('pool_3:0')
        for image in bar(images):
            if not tf.gfile.Exists(image):
                # NOTE(review): in TF1, tf.logging.fatal appears to only log
                # (it does not abort), so the read below is what fails for a
                # missing file — confirm.
                tf.logging.fatal('File does not exist %s', image)
            # Use a context manager so the file handle is closed promptly.
            with tf.gfile.FastGFile(image, 'rb') as f:
                image_data = f.read()
            # Run the embedding (pool_3) tensor, feeding the raw image bytes.
            embedding = sess.run(embedding_tensor,
                                 {'DecodeJpeg/contents:0': image_data})
            filename_to_emb[image] = embedding.reshape(2048)
    return filename_to_emb
# temp_dir is a subdir of temp
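# NOTE(review): the original file contained two lines of web-page navigation
# text here ("评论列表" = "Comment list", "文章目录" = "Table of contents").
# They were scraping residue rather than code, so they are kept as a comment
# so the file still parses.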