def test_restore_map_for_classification_ckpt(self):
    """Check that restore_map() can load a classification checkpoint.

    A two-layer convolutional graph is built under the 'mock_model' scope
    and saved to a checkpoint. The detection graph is then built by running
    preprocess/predict/postprocess on ``self._model``, and the mapping
    returned by ``restore_map(from_detection_checkpoint=False)`` is handed
    to a Saver that restores that checkpoint. After the restore, none of the
    variables reported as uninitialized may contain 'FeatureExtractor' in
    its name.
    """
    # Define mock tensorflow classification graph and save variables.
    test_graph_classification = tf.Graph()
    with test_graph_classification.as_default():
        image = tf.placeholder(dtype=tf.float32, shape=[1, 20, 20, 3])
        with tf.variable_scope('mock_model'):
            net = slim.conv2d(
                image, num_outputs=32, kernel_size=1, scope='layer1')
            slim.conv2d(net, num_outputs=3, kernel_size=1, scope='layer2')

        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()
        save_path = self.get_temp_dir()
        # Bind the session to this graph explicitly. NOTE(review): the
        # argument-less test_session() may hand back a session cached for
        # an earlier graph -- see the tf.test.TestCase documentation.
        with self.test_session(graph=test_graph_classification) as sess:
            sess.run(init_op)
            saved_model_path = saver.save(sess, save_path)

    # Create tensorflow detection graph and load variables from
    # classification checkpoint.
    test_graph_detection = tf.Graph()
    with test_graph_detection.as_default():
        inputs_shape = [2, 2, 2, 3]
        inputs = tf.to_float(tf.random_uniform(
            inputs_shape, minval=0, maxval=255, dtype=tf.int32))
        preprocessed_inputs = self._model.preprocess(inputs)
        prediction_dict = self._model.predict(preprocessed_inputs)
        self._model.postprocess(prediction_dict)
        var_map = self._model.restore_map(from_detection_checkpoint=False)
        self.assertIsInstance(var_map, dict)
        saver = tf.train.Saver(var_map)
        # The restore and the check below must run against the detection
        # graph, so the session is tied to it explicitly here as well.
        with self.test_session(graph=test_graph_detection) as sess:
            saver.restore(sess, saved_model_path)
            # report_uninitialized_variables() evaluates to an array of
            # variable *names* (byte strings), not Variable objects, so
            # there is no `.name` attribute to read from each element.
            for var in sess.run(tf.report_uninitialized_variables()):
                if isinstance(var, bytes):
                    var = var.decode('utf-8')
                self.assertNotIn('FeatureExtractor', var)
# (Web-page residue, not part of the test: "comment list" / "article table of contents".)