task.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:kaggle-youtube-8m 作者: liufuyang 项目源码 文件源码
def build_and_run_exports(latest, job_dir, name, serving_input_fn, hidden_units):
  """Export a serving SavedModel built from the given checkpoint.

  Builds a fresh prediction graph with ``model.model_fn`` in PREDICT mode,
  restores the checkpointed weights from ``latest`` into it, and writes a
  SavedModel with a single default serving signature to
  ``<job_dir>/export/<name>``.

  Args:
    latest (string): Path of the checkpoint to restore (typically the latest
      one).
    job_dir (string): Location of checkpoints and model files; the export is
      written under ``<job_dir>/export/``.
    name (string): Name of the checkpoint to be exported. Used in building the
      export path.
    serving_input_fn (callable): Zero-argument function returning a
      ``(features, inputs_dict)`` pair: the features fed to ``model_fn`` and a
      dict mapping input names to the tensors exposed in the serving signature.
    hidden_units (list): Number of hidden units, forwarded to ``model_fn``.
  """
  export_dir = os.path.join(job_dir, 'export', name)
  prediction_graph = tf.Graph()
  # NOTE(review): SavedModelBuilder is expected to refuse an export directory
  # that already exists, so re-exporting under the same `name` presumably
  # fails -- confirm.
  exporter = tf.saved_model.builder.SavedModelBuilder(export_dir)
  with prediction_graph.as_default():
    features, inputs_dict = serving_input_fn()
    prediction_dict = model.model_fn(
        model.PREDICT,
        features,
        None,  # labels
        hidden_units=hidden_units,
        learning_rate=None  # learning_rate unused in prediction mode
    )
    # Created after the model so it covers the variables model_fn just built.
    saver = tf.train.Saver()

    # `.items()` works on Python 2 and 3 (`.iteritems()` is Python 2 only).
    # `key` is used rather than reusing `name`, which is the export name.
    inputs_info = {
        key: tf.saved_model.utils.build_tensor_info(tensor)
        for key, tensor in inputs_dict.items()
    }
    output_info = {
        key: tf.saved_model.utils.build_tensor_info(tensor)
        for key, tensor in prediction_dict.items()
    }
    signature_def = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=inputs_info,
        outputs=output_info,
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
    )

  # Entering the session also makes `prediction_graph` the default graph, so
  # the initializer ops below are created in that graph.
  with tf.Session(graph=prediction_graph) as session:
    # Initialize local variables and lookup tables (the default Saver
    # presumably restores only global variables), then load the weights.
    session.run([tf.local_variables_initializer(), tf.tables_initializer()])
    saver.restore(session, latest)
    # NOTE(review): no main_op / legacy_init_op is passed, so lookup tables are
    # presumably not re-initialized when the SavedModel is loaded for serving.
    # Confirm whether model_fn uses tables.
    exporter.add_meta_graph_and_variables(
        session,
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
        },
    )

  exporter.save()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号