def test_restore_map_for_classification_ckpt(self):
    """Check that restore_map() can load a classification checkpoint.

    A small two-layer conv graph is created under the 'mock_model' variable
    scope and saved to a checkpoint. The detection model is then built in a
    separate graph and restored from that checkpoint through the mapping
    returned by ``restore_map(from_detection_checkpoint=False)``. Finally,
    no variable that is still reported as uninitialized may belong to the
    first- or second-stage feature extractor scope.
    """
    # Define mock tensorflow classification graph and save variables.
    test_graph_classification = tf.Graph()
    with test_graph_classification.as_default():
        image = tf.placeholder(dtype=tf.float32, shape=[1, 20, 20, 3])
        with tf.variable_scope('mock_model'):
            net = slim.conv2d(
                image, num_outputs=3, kernel_size=1, scope='layer1')
            slim.conv2d(net, num_outputs=3, kernel_size=1, scope='layer2')

        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()
        save_path = self.get_temp_dir()
        # Two graphs are used in this test, so bind each session to its
        # graph explicitly instead of relying on the default session.
        with self.test_session(graph=test_graph_classification) as sess:
            sess.run(init_op)
            saved_model_path = saver.save(sess, save_path)

    # Create tensorflow detection graph and load variables from
    # classification checkpoint.
    test_graph_detection = tf.Graph()
    with test_graph_detection.as_default():
        model = self._build_model(
            is_training=False,
            first_stage_only=False,
            second_stage_batch_size=6)

        inputs_shape = (2, 20, 20, 3)
        inputs = tf.to_float(tf.random_uniform(
            inputs_shape, minval=0, maxval=255, dtype=tf.int32))
        preprocessed_inputs = model.preprocess(inputs)
        prediction_dict = model.predict(preprocessed_inputs)
        model.postprocess(prediction_dict)

        var_map = model.restore_map(from_detection_checkpoint=False)
        self.assertIsInstance(var_map, dict)
        saver = tf.train.Saver(var_map)
        with self.test_session(graph=test_graph_detection) as sess:
            saver.restore(sess, saved_model_path)
            uninitialized = sess.run(tf.report_uninitialized_variables())
            for raw_name in uninitialized:
                # report_uninitialized_variables() evaluates to variable
                # *names* (byte strings), not Variable objects, so there is
                # no `.name` attribute; decode and compare the text itself.
                var_name = tf.compat.as_text(raw_name)
                self.assertNotIn(
                    model.first_stage_feature_extractor_scope, var_name)
                self.assertNotIn(
                    model.second_stage_feature_extractor_scope, var_name)
# Source file: faster_rcnn_meta_arch_test_lib.py
# Language: python
# Views: 23
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents