model.py source code

Project: cloudml-samples  Author: GoogleCloudPlatform

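import logging
import os

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Saver)


def export(self, last_checkpoint, output_dir):
    """Builds a prediction graph and exports the model.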

    Args:
      last_checkpoint: The latest checkpoint from training.
      output_dir: Path to the folder to be used to output the model.
    """
    logging.info('Exporting prediction graph to %s', output_dir)
    with tf.Session(graph=tf.Graph()) as sess:
      # Build and save prediction meta graph and trained variable values.
      self.build_prediction_graph()
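      # tf.global_variables_initializer() replaced tf.initialize_all_variables()
      # in TensorFlow 0.12; the fallback keeps older versions working.
      # Remove this try/except once TensorFlow 0.12 is standard.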
      try:
        init_op = tf.global_variables_initializer()
      except AttributeError:
        init_op = tf.initialize_all_variables()
      sess.run(init_op)
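      # Restore the trained weights into the prediction graph. Variables are
      # matched by name, so the restored values overwrite those set by init_op.
      trained_saver = tf.train.Saver()
      trained_saver.restore(sess, last_checkpoint)
      # Write the prediction graph to <output_dir>/export.meta and the variable
      # values under the checkpoint prefix <output_dir>/export. The meta graph
      # is exported separately, hence write_meta_graph=False.
      saver = tf.train.Saver()
      saver.export_meta_graph(filename=os.path.join(output_dir, 'export.meta'))
      saver.save(
          sess, os.path.join(output_dir, 'export'), write_meta_graph=False)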
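The try/except around the initializer in export() is a feature-detection shim: it prefers tf.global_variables_initializer() and falls back to the deprecated tf.initialize_all_variables() when the newer name does not exist. Below is a minimal, TensorFlow-free sketch of the same pattern. The helper resolve_api and the stand-in namespaces new_tf and old_tf are hypothetical illustrations, not part of cloudml-samples or TensorFlow.

```python
from types import SimpleNamespace


def resolve_api(module, preferred, fallback):
    """Return module.<preferred> if it exists, else module.<fallback>.

    Hypothetical helper mirroring the try/except AttributeError shim in
    export(); `module` can be any object that exposes attributes.
    """
    try:
        return getattr(module, preferred)
    except AttributeError:
        # Older releases only expose the deprecated name.
        return getattr(module, fallback)


# Stand-ins for a newer and an older TensorFlow build (illustrative only).
new_tf = SimpleNamespace(
    global_variables_initializer=lambda: "new-init",
    initialize_all_variables=lambda: "old-init",
)
old_tf = SimpleNamespace(initialize_all_variables=lambda: "old-init")

NEW, OLD = "global_variables_initializer", "initialize_all_variables"
print(resolve_api(new_tf, NEW, OLD)())  # new-init
print(resolve_api(old_tf, NEW, OLD)())  # old-init
```

This sketch resolves the attribute first and calls it afterwards. As a result, an AttributeError raised *inside* the initializer is not mistaken for a missing API. The original try block wraps the call itself, so it would treat such an error as a reason to fall back.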
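trained_saver.restore(sess, last_checkpoint) can load training weights into a separately built prediction graph because tf.train.Saver, by default, saves and restores each variable under its name. Every variable in the prediction graph needs a same-named entry in the checkpoint, or restore raises an error. Checkpoint entries that have no matching variable (for example, optimizer slots used only during training) are simply not loaded. The toy model below illustrates that contract with plain dictionaries. It is a hypothetical sketch, not TensorFlow's implementation, and the variable names and values are made up.

```python
def restore_by_name(graph_vars, checkpoint):
    """Toy model of name-based restore (not TensorFlow's implementation).

    graph_vars: dict mapping variable name -> current (e.g. freshly
        initialized) value in the prediction graph.
    checkpoint: dict mapping variable name -> value saved during training.
    Returns a new dict with every graph variable set to its saved value.
    """
    missing = sorted(name for name in graph_vars if name not in checkpoint)
    if missing:
        # tf.train.Saver.restore also raises an error in the analogous case.
        raise KeyError(f"variables not found in checkpoint: {missing}")
    # Entries present only in the checkpoint are ignored.
    return {name: checkpoint[name] for name in graph_vars}


# Hypothetical values: the prediction graph starts with initializer values.
prediction_graph = {"dense/kernel": 0.0, "dense/bias": 0.0}
training_checkpoint = {
    "dense/kernel": 0.73,
    "dense/bias": -0.12,
    "dense/kernel/Adam": 0.01,  # training-only slot, absent from the graph
    "global_step": 5000,
}
print(restore_by_name(prediction_graph, training_checkpoint))
# {'dense/kernel': 0.73, 'dense/bias': -0.12}
```

In this toy model the initial 0.0 values never survive the restore. Likewise, in export() every variable covered by the Saver is overwritten by restore, so running init_op beforehand does not change the exported weights.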