def __init__(self, saved_model_dir, input_schema, exclude_outputs,
             tf_config):
    """Load a transform SavedModel into an owned session and select tensors.

    Creates ``self.session`` on a fresh ``tf.Graph``, applies the saved
    transform from ``saved_model_dir`` into that graph, runs the graph's table
    initializers in ``self.session``, and records the input/output tensors
    this instance exposes.

    Args:
        saved_model_dir: SavedModel directory, forwarded to
            ``saved_transform_io.partially_apply_saved_transform``.
        input_schema: Object with a ``column_schemas`` mapping; its keys name
            the graph inputs kept in ``self.inputs``.
        exclude_outputs: Iterable of output names omitted from
            ``self.outputs``. Any iterable is accepted (it is materialized
            once).
        tf_config: Session config passed to every ``tf.Session`` created
            here (presumably a ``tf.ConfigProto`` or ``None`` -- confirm
            against callers).

    Raises:
        ValueError: If ``input_schema`` names inputs that are not in the
            loaded graph, or ``exclude_outputs`` names outputs that are not in
            the loaded graph. The message lists only the offending keys.
    """
    self.saved_model_dir = saved_model_dir
    self.session = tf.Session(graph=tf.Graph(), config=tf_config)
    try:
        with self.session.graph.as_default():
            # NOTE(review): this short-lived session presumably only serves
            # as the default session while the transform graph is applied;
            # it is closed as soon as loading finishes -- confirm intent.
            with tf.Session(config=tf_config):
                inputs, outputs = (
                    saved_transform_io.partially_apply_saved_transform(
                        saved_model_dir, {}))
            # Initialize lookup tables in the long-lived session, since the
            # loading session above has already been closed.
            self.session.run(tf.tables_initializer())

            input_schema_keys = input_schema.column_schemas.keys()
            extra_input_keys = set(input_schema_keys).difference(inputs)
            if extra_input_keys:
                # Wrap in a 1-tuple so '%' never unpacks the value, and list
                # only the keys that are actually missing from the graph.
                raise ValueError(
                    'Input schema contained keys not in graph: %s' %
                    (sorted(extra_input_keys, key=str),))

            # Materialize once: a one-shot iterable would otherwise be
            # exhausted by the membership check and exclude nothing below.
            excluded = set(exclude_outputs)
            extra_output_keys = excluded.difference(outputs)
            if extra_output_keys:
                raise ValueError(
                    'Excluded outputs contained keys not in graph: %s' %
                    (sorted(extra_output_keys, key=str),))

            non_excluded_output_keys = set(outputs).difference(excluded)
            self.inputs = {key: inputs[key] for key in input_schema_keys}
            self.outputs = {
                key: outputs[key] for key in non_excluded_output_keys}
    except BaseException:
        # Do not leak the owned session when construction fails part-way.
        self.session.close()
        raise
# (Scraped web-page residue, translated: "Comment list" / "Table of contents".)