def parse_csv(rows_string_tensor):
    """Decode a string tensor of CSV lines into a dict of column tensors.

    Each CSV record is split into fields by ``tf.decode_csv`` using
    ``CSV_COLUMN_DEFAULTS`` as the per-column defaults. The resulting
    column tensors are keyed by the names in ``CSV_COLUMNS``, and every
    column listed in ``UNUSED_COLUMNS`` is dropped before returning.

    Args:
        rows_string_tensor: String tensor whose elements are CSV-formatted
            lines, e.g. ``['csv,line,1', 'csv,line,2', ...]``.

    Returns:
        A dict mapping column name to the decoded tensor for that column,
        without the unused columns.

    Raises:
        KeyError: If a name in ``UNUSED_COLUMNS`` is not among the parsed
            feature names.
    """
    # NOTE(review): an earlier comment here described expanding the rank-1
    # input to rank-2 (['a,b,1', ...] -> [['a,b,1'], ...]) so the outputs
    # would be rank-2, but no such reshape happens in this function. The
    # output rank presumably follows the input's rank -- confirm that callers
    # pass the shape they expect.
    decoded_columns = tf.decode_csv(
        rows_string_tensor,
        record_defaults=CSV_COLUMN_DEFAULTS,
    )

    # Pair each column name with its decoded tensor, in declaration order.
    features = {
        name: tensor for name, tensor in zip(CSV_COLUMNS, decoded_columns)
    }

    # Drop the columns that are parsed but not used downstream.
    for unused_name in UNUSED_COLUMNS:
        del features[unused_name]

    return features
评论列表
文章目录