def parse_csv(rows_string_tensor):
    """Decode CSV rows into a dict of feature tensors with a trailing unit dim.

    Each row in ``rows_string_tensor`` is decoded with
    ``CSV_COLUMN_DEFAULTS``. The decoded columns are keyed by the names in
    ``CSV_COLUMNS``. Every column listed in ``UNUSED_COLUMNS`` is dropped,
    and each remaining column gets an extra trailing dimension of size 1.

    Args:
        rows_string_tensor: String tensor of CSV-encoded rows. Presumably a
            1-D batch of lines, in which case every returned tensor is
            rank 2 (``[batch, 1]``). TODO confirm against callers.

    Returns:
        dict mapping each used column name to its decoded tensor, expanded
        with ``tf.expand_dims(..., -1)``.

    Raises:
        KeyError: if a name in ``UNUSED_COLUMNS`` does not appear in
            ``CSV_COLUMNS`` (``dict.pop`` is called without a default).
    """
    columns = tf.decode_csv(
        rows_string_tensor, record_defaults=CSV_COLUMN_DEFAULTS)
    # NOTE(review): zip() silently truncates if CSV_COLUMNS and
    # CSV_COLUMN_DEFAULTS differ in length -- keep them in sync.
    features = dict(zip(CSV_COLUMNS, columns))
    # Remove unused columns.
    for col in UNUSED_COLUMNS:
        features.pop(col)
    # Append a size-1 last axis to every remaining column.
    return {key: tf.expand_dims(value, -1) for key, value in features.items()}
评论列表
文章目录