def extract_features(self,
                     checkpoint_path,
                     inputs,
                     layer_name,
                     num_classes=0):
    """Restore the model parameters from checkpoint_path, look up the layer
    named `layer_name` in the model, feed `inputs` to the model and return
    the values computed by that layer.
    Args:
        checkpoint_path: path of the trained model checkpoint directory
        inputs: a Tensor with a shape compatible with the model's input
        layer_name: a string, the name of the layer to extract from the model
        num_classes: number of classes to classify; it must match the number of
            classes the classifier was trained on whenever the model is
            class-aware (e.g. a classifier), otherwise leave it at 0
    Returns:
        features: a numpy ndarray containing the extracted features
    """
    # Evaluate the inputs in the current default graph,
    # then use a placeholder to inject the computed values into the new graph
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True)) as sess:
        evaluated_inputs = sess.run(inputs)

    # Create a new graph, so that subsequent calls do not pollute the
    # default graph
    with tf.Graph().as_default() as graph:
        inputs_ = tf.placeholder(inputs.dtype, shape=inputs.shape)

        # Build a Graph that computes the predictions from the inference model
        _ = self._model.get(
            inputs_, num_classes, train_phase=False, l2_penalty=0.0)

        # graph.get_tensor_by_name raises an exception if layer_name is not found
        layer = graph.get_tensor_by_name(layer_name)

        saver = tf.train.Saver(variables_to_restore())
        init = [
            tf.variables_initializer(
                tf.global_variables() + tf.local_variables()),
            tf.tables_initializer()
        ]
        # Zero-filled fallback, returned if no checkpoint is found
        features = np.zeros(layer.shape)
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            # Initialize variables and tables first, then restore the
            # checkpoint, so that the restored weights are not overwritten
            sess.run(init)
            ckpt = tf.train.get_checkpoint_state(checkpoint_path)
            if ckpt and ckpt.model_checkpoint_path:
                # Restore the trained weights from the checkpoint
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('[!] No checkpoint file found')
                return features
            features = sess.run(
                layer, feed_dict={inputs_: evaluated_inputs})
    return features
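
As a usage sketch (not part of the original source): the snippet below shows how the method's contract could be exercised. It assumes `evaluator` is an already-built instance of the class that defines extract_features, and the checkpoint directory 'log/model/best', the tensor name 'model/fc1/Relu:0' and num_classes=10 are purely hypothetical placeholders.

# Usage sketch. Assumptions (all hypothetical): `evaluator` is an instance of
# the class defining extract_features, 'model/fc1/Relu:0' is a tensor that
# exists in the built graph, and 'log/model/best' contains a trained checkpoint.
import numpy as np
import tensorflow as tf

# A batch of 10 RGB images with a fully-defined shape, as required by
# tf.placeholder(inputs.dtype, shape=inputs.shape) inside extract_features
images = tf.constant(np.random.rand(10, 64, 64, 3), dtype=tf.float32)

features = evaluator.extract_features(
    checkpoint_path='log/model/best',  # hypothetical checkpoint directory
    inputs=images,
    layer_name='model/fc1/Relu:0',     # hypothetical tensor name, note the ':0' suffix
    num_classes=10)                    # classes the (hypothetical) classifier was trained on
print(features.shape)                  # numpy ndarray shaped like the layer's output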