cnnpredictor.py source code

python

Project: DmsMsgRcg  Author: bshao001
# Imports this method relies on (TensorFlow 1.x graph-mode API).
import os

import tensorflow as tf

def __init__(self, session, model_scope, result_dir, result_file, k=1):
        """
        Args:
            session: The TensorFlow session used to run the prediction.
            model_scope: The variable_scope used for the trained model to be restored.
            result_dir: The full path to the folder in which the result file is located.
            result_file: The file that saves the training results.
            k: Optional. Number of top predictions to return.
        """
        tf.train.import_meta_graph(os.path.join(result_dir, result_file + ".meta"))
        all_vars = tf.global_variables()
        model_vars = [var for var in all_vars if var.name.startswith(model_scope)]
        saver = tf.train.Saver(model_vars)
        saver.restore(session, os.path.join(result_dir, result_file))

        # Retrieve the Ops we 'remembered'.
        logits = tf.get_collection(model_scope+"logits")[0]
        self.images_placeholder = tf.get_collection(model_scope+"images")[0]
        self.keep_prob_placeholder = tf.get_collection(model_scope+"keep_prob")[0]

        # Add an Op that chooses the top k predictions. Apply softmax so that
        # we can have the probabilities (percentage) in the output.
        self.eval_op = tf.nn.top_k(tf.nn.softmax(logits), k=k)
        self.session = session
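
To make the last step concrete, here is a minimal pure-Python sketch of what applying softmax and then selecting the top k entries computes for a single vector of logits. It is an illustration only, not the project's code: the helper names softmax and top_k and the sample logits are hypothetical, and it leaves out the batching that tf.nn.top_k does over the last dimension.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, k=1):
    """Return (values, indices) of the k largest entries, largest first."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [probs[i] for i in order], order


# Hypothetical logits for a three-class model.
values, indices = top_k(softmax([2.0, 1.0, 0.1]), k=2)
print(indices, values)
```

As in the TensorFlow op, the result is a pair: the probabilities of the k most likely classes and the class indices they belong to.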