def predict(model_scope, result_dir, result_file, img_features, k=1):
    """Restore a saved model and return its top-k softmax predictions.

    Args:
        model_scope: The variable_scope used when this model was trained. It is
            prepended (with no separator) to the collection names "logits",
            "images" and "keep_prob" to find the ops saved at training time.
        result_dir: The full path to the folder in which the result file is
            located.
        result_file: The base name of the saved training results. The graph is
            read from ``result_file + ".meta"`` and the variables from
            ``result_file``, both inside ``result_dir``.
        img_features: A 2-D ndarray (matrix) each row of which holds the pixels
            as features of one image. One or more rows (image samples) can be
            predicted at once.
        k: Optional. Number of top predictions to return for each sample.

    Returns:
        A ``(values, indices)`` pair as produced by ``tf.nn.top_k`` applied to
        the softmax of the restored logits (along their last axis):
        ``values`` are the top-k probabilities, ``indices`` the matching
        positions. Refer to tf.nn.top_k for details.

    Raises:
        IndexError: If one of the expected collections is not present in the
            restored graph (e.g. ``model_scope`` does not match the scope used
            at training time).
    """

    def _first_in_collection(name):
        # tf.get_collection returns an empty list when nothing was stored under
        # `name`; report which collection is missing instead of an opaque
        # "list index out of range". IndexError is kept so callers that
        # already catch it keep working.
        items = tf.get_collection(name)
        if not items:
            raise IndexError(
                "collection %r not found in the restored graph; check that "
                "model_scope matches the scope used when training" % name)
        return items[0]

    meta_path = os.path.join(result_dir, result_file + ".meta")
    checkpoint_path = os.path.join(result_dir, result_file)

    # A new Graph per call keeps the imported ops separate from the default
    # graph and from ops imported by earlier calls.
    with tf.Session(graph=tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(meta_path)
        saver.restore(sess, checkpoint_path)

        # Retrieve the ops that were added to named collections at training time.
        logits = _first_in_collection(model_scope + "logits")
        images_placeholder = _first_in_collection(model_scope + "images")
        keep_prob_placeholder = _first_in_collection(model_scope + "keep_prob")

        # Add an op that chooses the top k predictions. Softmax is applied
        # first so the returned values are probabilities.
        eval_op = tf.nn.top_k(tf.nn.softmax(logits), k=k)

        feed = {
            images_placeholder: img_features,
            # keep_prob of 1.0 presumably disables dropout for inference.
            keep_prob_placeholder: 1.0,
        }
        values, indices = sess.run(eval_op, feed_dict=feed)
    return values, indices
评论列表
文章目录