impl.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:transform 作者: tensorflow 项目源码 文件源码
def __init__(self, saved_model_dir, input_schema, exclude_outputs,
                 tf_config):
      """Load a saved transform into a private graph and select its tensors.

      A new `tf.Graph` and a `tf.Session` bound to it are created and kept on
      `self.session`. The saved transform is applied to that graph with an
      empty input mapping, and `tf.tables_initializer()` is then run in
      `self.session`.

      Args:
        saved_model_dir: Location passed to
          `saved_transform_io.partially_apply_saved_transform`; also stored on
          `self.saved_model_dir`.
        input_schema: Object exposing a `column_schemas` mapping; its keys
          select which graph inputs are kept in `self.inputs`.
        exclude_outputs: Iterable of output keys to leave out of
          `self.outputs`.
        tf_config: Session config forwarded to every `tf.Session` created here.

      Raises:
        ValueError: If `input_schema` names a key that is not an input of the
          loaded graph, or `exclude_outputs` names a key that is not an output
          of the loaded graph.
      """
      self.saved_model_dir = saved_model_dir
      self.session = tf.Session(graph=tf.Graph(), config=tf_config)
      with self.session.graph.as_default():
        # The transform is loaded under a temporary default session.
        # NOTE(review): presumably the loader needs an active session while
        # importing the graph -- confirm against saved_transform_io.
        with tf.Session(config=tf_config):
          inputs, outputs = saved_transform_io.partially_apply_saved_transform(
              saved_model_dir, {})
        self.session.run(tf.tables_initializer())

        input_schema_keys = input_schema.column_schemas.keys()
        extra_input_keys = set(input_schema_keys).difference(inputs.keys())
        if extra_input_keys:
          # Report only the offending keys (sorted for a stable message)
          # rather than the whole schema key list.
          raise ValueError('Input schema contained keys not in graph: %s' %
                           sorted(extra_input_keys))
        extra_output_keys = set(exclude_outputs).difference(outputs.keys())
        if extra_output_keys:
          # Likewise, name just the excluded keys the graph does not have.
          raise ValueError('Excluded outputs contained keys not in graph: %s' %
                           sorted(extra_output_keys))
        non_excluded_output_keys = set(outputs.keys()).difference(
            exclude_outputs)
        self.inputs = {key: inputs[key] for key in input_schema_keys}
        self.outputs = {key: outputs[key] for key in non_excluded_output_keys}
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号