imgconvnets.py source code

python

Project: DmsMsgRcg    Author: bshao001
import os

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.import_meta_graph)


def predict(model_scope, result_dir, result_file, img_features, k=1):
    """
    Args:
        model_scope: The variable_scope used when this model was trained.
        result_dir: The full path to the folder in which the result file is located.
        result_file: The name of the file that saves the training results (the
            checkpoint prefix; the graph itself is read from result_file + ".meta").
        img_features: A 2-D ndarray (matrix), each row of which holds the pixels of
            one image as features. One or more rows (image samples) can be
            predicted at once.
        k: Optional. Number of top predictions to return for each image (default 1).
    Returns:
        values and indices: the k highest softmax probabilities for each image and
        the class indices they belong to. Refer to tf.nn.top_k for details.
    """
    # Use a fresh graph; entering the session makes it the default graph.
    with tf.Session(graph=tf.Graph()) as sess:
        # Rebuild the saved graph structure, then load the trained variable values.
        saver = tf.train.import_meta_graph(os.path.join(result_dir, result_file + ".meta"))
        saver.restore(sess, os.path.join(result_dir, result_file))

        # Retrieve the ops that were 'remembered' in collections at training time.
        logits = tf.get_collection(model_scope + "logits")[0]
        images_placeholder = tf.get_collection(model_scope + "images")[0]
        keep_prob_placeholder = tf.get_collection(model_scope + "keep_prob")[0]

        # Add an op that chooses the top k predictions. Apply softmax so that
        # the output values are probabilities rather than raw logits.
        eval_op = tf.nn.top_k(tf.nn.softmax(logits), k=k)

        # Feed keep_prob = 1.0 so dropout is disabled at inference time.
        values, indices = sess.run(eval_op, feed_dict={images_placeholder: img_features,
                                                       keep_prob_placeholder: 1.0})

        return values, indices
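The post-processing step `tf.nn.top_k(tf.nn.softmax(logits), k=k)` in `predict` can be mimicked in plain Python, which helps when checking what `values` and `indices` should look like without loading a checkpoint. This is a minimal sketch that is not part of the DmsMsgRcg project: `softmax` and `top_k_softmax` are hypothetical helper names, and it assumes `logits` is a list of rows (one per image) with one score per class.

```python
import math
from typing import List, Sequence, Tuple


def softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(row)  # subtract the row max so exp() cannot overflow
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_softmax(logits: Sequence[Sequence[float]], k: int = 1
                  ) -> Tuple[List[List[float]], List[List[int]]]:
    """Per row: the k largest softmax probabilities and their class indices.

    Intended to mirror tf.nn.top_k(tf.nn.softmax(logits), k=k) with the
    default sorted=True: values are in descending order, and equal values
    keep the lower index first.
    """
    values, indices = [], []
    for row in logits:
        probs = softmax(row)
        # sorted() is stable, so equal probabilities keep ascending index order.
        order = sorted(range(len(probs)), key=lambda i: -probs[i])[:k]
        values.append([probs[i] for i in order])
        indices.append(order)
    return values, indices


if __name__ == "__main__":
    batch = [[2.0, 1.0, 0.1],   # one row per image, one column per class
             [0.0, 0.0, 3.0]]
    vals, idx = top_k_softmax(batch, k=2)
    print(idx)  # [[0, 1], [2, 0]]
```

Because softmax is monotonic within a row, the returned indices are the same as those of the k largest raw logits; applying softmax only turns `values` into probabilities that sum to 1 across all classes.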