CheckpointExporterHook.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:ISLES2017 作者: MiguelMonteiro 项目源码 文件源码
def _run_export(self):
        """Export the latest checkpoint as a SavedModel with a predict signature.

        The export directory is ``export_ckpt_<N>`` inside
        ``self._checkpoint_dir``, where ``<N>`` is the last run of digits found
        in ``self._latest_checkpoint``. A fresh graph is built from
        ``model.serving_input_fn`` / ``model.model_fn``, the checkpoint is
        restored into it, and the graph plus variables are saved under the
        default serving signature key with the SERVING tag.

        If the SavedModel builder cannot be created for the export directory
        the checkpoint is treated as already exported: a message is logged and
        the method returns without doing anything else.

        Raises:
            IndexError: if ``self._latest_checkpoint`` contains no digits.
        """
        # Raw string: '\d' in a plain literal is an invalid escape sequence.
        step_suffix = re.findall(r'\d+', self._latest_checkpoint)[-1]
        export_dir = 'export_ckpt_' + step_suffix
        tf.logging.info('Exporting model from checkpoint {0}'.format(self._latest_checkpoint))

        try:
            exporter = tf.saved_model.builder.SavedModelBuilder(
                os.path.join(self._checkpoint_dir, export_dir))
        except (IOError, AssertionError):
            # NOTE(review): TF 1.x SavedModelBuilder appears to signal an
            # existing export directory with AssertionError rather than
            # IOError; both are handled so the skip path is actually reached.
            # Confirm against the TensorFlow version in use.
            tf.logging.info('Checkpoint {0} already exported, continuing...'.format(self._latest_checkpoint))
            return

        # Build the inference graph separately from whatever graph is
        # currently the default.
        prediction_graph = tf.Graph()
        with prediction_graph.as_default():
            image, name, inputs_dict = model.serving_input_fn()
            # NOTE(review): the positional arguments (None, 6, None) are
            # defined by model.model_fn, which is not visible here -- confirm
            # their meaning against that function's signature.
            prediction_dict = model.model_fn(model.PREDICT, name, image, None, 6, None)

            saver = tf.train.Saver()

            # .items() works on both Python 2 and 3 (.iteritems() is Py2-only).
            # The comprehension key is named `key` so it does not shadow the
            # `name` tensor returned by serving_input_fn above.
            inputs_info = {key: tf.saved_model.utils.build_tensor_info(tensor)
                           for key, tensor in inputs_dict.items()}

            output_info = {key: tf.saved_model.utils.build_tensor_info(tensor)
                           for key, tensor in prediction_dict.items()}

            signature_def = tf.saved_model.signature_def_utils.build_signature_def(
                inputs=inputs_info,
                outputs=output_info,
                method_name=sig_constants.PREDICT_METHOD_NAME
            )

        with tf.Session(graph=prediction_graph) as session:
            # Load the trained weights into the freshly built graph before
            # handing the session to the SavedModel builder.
            saver.restore(session, self._latest_checkpoint)
            exporter.add_meta_graph_and_variables(
                session,
                tags=[tf.saved_model.tag_constants.SERVING],
                signature_def_map={sig_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def},
                legacy_init_op=my_main_op()
            )

        exporter.save()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号