def parse_csv(schema, instances, prediction):
    """Parse a batch of csv instance strings into a dict of per-field tensors.

    A wrapper around tf.decode_csv that derives the column defaults from the
    provided schema. Each returned tensor has shape (None, 1).

    Args:
      schema: iterable of schema fields; each field exposes .name, .type and
        .length (only length-1 fields are supported).
      instances: 1-D string tensor of csv lines.
      prediction: if True, tolerate instances that are missing the target
        column (assumed to be the first column).

    Returns:
      Dict mapping field name to a (None, 1) tensor of parsed values.

    Raises:
      ValueError: if any schema field has a length other than 1.
    """
    if prediction:
        # Training/eval data always carries the target column, but prediction
        # input may omit it (unknown target) or include it (model evaluation).
        # To keep a single prediction graph, detect the missing target by
        # comparing the column count of the first instance against the schema
        # size, and prepend a ',' placeholder when columns are missing
        # (assumes the target is the first column).
        column_count = tf.shape(
            tf.string_split([instances[0]], delimiter=',').values)[0]
        instances = tf.cond(
            tf.less(column_count, len(schema)),
            lambda: tf.string_join([tf.constant(','), instances]),
            lambda: instances)

    def _default_for(field):
        # Produce the decode_csv default constant for a single schema field.
        if field.length != 1:
            # TODO: Support variable length, and list columns in csv.
            raise ValueError('Unsupported schema field "%s". Length must be 1.' % field.name)
        if field.type == SchemaFieldType.integer:
            return tf.constant(0, dtype=tf.int64)
        if field.type == SchemaFieldType.real:
            return tf.constant(0.0, dtype=tf.float32)
        # discrete, text, binary
        return tf.constant('', dtype=tf.string)

    # One single-element default list per column, as decode_csv expects.
    defaults = [[_default_for(field)] for field in schema]
    values = tf.decode_csv(instances, defaults, name='csv')

    # decode_csv yields scalars per instance, i.e. tensors of shape (None,);
    # expand each to (None, 1) so downstream code sees column vectors.
    return {
        field.name: tf.expand_dims(value, axis=1, name=field.name)
        for field, value in zip(schema, values)
    }