Evaluator.py source code

python
Project: dynamic-training-bench, author: galeone
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.placeholder, tf.train.Saver)

# NOTE: `variables_to_restore` is a helper defined elsewhere in the
# dynamic-training-bench project; its import is not part of this excerpt.


def extract_features(self,
                     checkpoint_path,
                     inputs,
                     layer_name,
                     num_classes=0):
    """Restore the model parameters from `checkpoint_path`, look up the tensor
    named `layer_name` in the model, feed `inputs` to the model and return the
    values computed by that tensor.

    Args:
        checkpoint_path: path of the directory that contains the trained model checkpoint
        inputs: a Tensor with a shape compatible with the model's input
        layer_name: a string, the full name of the tensor to extract from the model,
            in the `<op_name>:<output_index>` form expected by `Graph.get_tensor_by_name`
        num_classes: the number of classes the model was trained on, if the model is a
            classifier or is otherwise class-aware; 0 otherwise
    Returns:
        features: a numpy ndarray that contains the extracted features (an array of
            zeros if no checkpoint is found)
    """

    # Evaluate `inputs` in the current default graph, then use a placeholder
    # to inject the computed values into the new graph
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True)) as sess:
        evaluated_inputs = sess.run(inputs)

    # Build the model in a new graph, so that subsequent calls do not keep
    # adding operations to the default graph
    with tf.Graph().as_default() as graph:
        inputs_ = tf.placeholder(inputs.dtype, shape=inputs.shape)

        # Build a graph that computes the predictions from the inference model
        _ = self._model.get(
            inputs_, num_classes, train_phase=False, l2_penalty=0.0)

        # This raises an exception if `layer_name` is not a tensor of this graph
        layer = graph.get_tensor_by_name(layer_name)

        saver = tf.train.Saver(variables_to_restore())
        init = [
            tf.variables_initializer(
                tf.global_variables() + tf.local_variables()),
            tf.tables_initializer()
        ]
        # Fallback value, returned when no checkpoint is available
        features = np.zeros(layer.shape.as_list())
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            # Run the initializers BEFORE restoring: running them after
            # saver.restore would overwrite the restored weights with their
            # initial values
            sess.run(init)
            ckpt = tf.train.get_checkpoint_state(checkpoint_path)
            if ckpt and ckpt.model_checkpoint_path:
                # Restore the trained parameters from the checkpoint
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('[!] No checkpoint file found')
                return features
            features = sess.run(
                layer, feed_dict={
                    inputs_: evaluated_inputs
                })

        return features
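When a checkpoint is restored into a session, the variable initializers have to run before `saver.restore`; if they run afterwards they reset every restored variable to its initial value. The framework-free sketch below mimics the same flow (initialize, restore, feed the input, read one named activation) with a hypothetical two-unit dense layer whose parameters live in a JSON file standing in for a checkpoint. `init_params`, `forward`, the JSON format and the tensor names are illustrative assumptions, not part of dynamic-training-bench or the TensorFlow API.

```python
import json
import os
import tempfile

# Hypothetical toy "model": one dense layer (2 inputs -> 2 units) followed by a ReLU.
# Parameter and tensor names are made up to mimic TensorFlow naming.


def init_params():
    """Initial (untrained) parameter values, like running the variable initializers."""
    return {"dense/w": [[0.0, 0.0], [0.0, 0.0]], "dense/b": [0.0, 0.0]}


def forward(params, x):
    """Run the model on one input vector and return every activation, keyed by name."""
    w, b = params["dense/w"], params["dense/b"]
    pre = [sum(x[i] * w[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]
    return {"dense/BiasAdd:0": pre, "dense/Relu:0": [max(0.0, v) for v in pre]}


def extract_features(checkpoint_file, x, layer_name):
    """Restore the parameters, feed `x` and return the activation named `layer_name`."""
    params = init_params()  # 1. initialize every parameter first
    if layer_name not in forward(params, x):
        raise KeyError(layer_name)  # like get_tensor_by_name on an unknown name
    if not os.path.exists(checkpoint_file):
        print('[!] No checkpoint file found')
        return [0.0] * len(forward(params, x)[layer_name])  # zeros fallback
    with open(checkpoint_file) as f:
        params.update(json.load(f))  # 2. restore: trained values replace the initial ones
    return forward(params, x)[layer_name]


if __name__ == "__main__":
    ckpt = os.path.join(tempfile.mkdtemp(), "model.json")
    with open(ckpt, "w") as f:
        json.dump({"dense/w": [[1.0, -1.0], [2.0, 0.5]], "dense/b": [0.5, -4.0]}, f)
    print(extract_features(ckpt, [1.0, 2.0], "dense/Relu:0"))  # prints [5.5, 0.0]
```

Swapping steps 1 and 2 (restore, then initialize) would make every call return `[0.0, 0.0]`, which is the same silent failure the ordering rule above guards against.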