_ds_csv.py source file

python

Project: tensorfx  Author: TensorLab
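import tensorflow as tf

# SchemaFieldType is defined elsewhere in the tensorfx package and must also be imported here.


def parse_csv(schema, instances, prediction):
  """Parses CSV instances into a dict of tensors, based on the provided Schema information.

  This is a wrapper around tf.decode_csv.

  Args:
    schema: the list of schema fields; each field has a name, a type and a length.
    instances: a 1-D string tensor, with one CSV line per instance.
    prediction: whether the instances are prediction inputs, which may omit the target column.
  Returns:
    A dict mapping each field name to a tensor of shape (None, 1).
  """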
  if prediction:
    # For training and evaluation data, the expectation is the target column is always present.
    # For prediction however, the target may or may not be present.
    # - In true prediction use-cases, the target is unknown and never present.
    # - In prediction for model evaluation use-cases, the target is present.
    # To use a single prediction graph, the missing target needs to be detected by comparing
    # number of columns in instances with number of columns defined in the schema. If there are
    # fewer columns, then prepend a ',' (with assumption that target is always the first column).
    #
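    # To get the number of columns in instances, split the first instance on ',' and take the
    # first dimension of the shape of the resulting token values. Note that tf.string_split
    # drops empty tokens by default, so empty fields in the first instance are not counted.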
    columns = tf.shape(tf.string_split([instances[0]], delimiter=',').values)[0]
    instances = tf.cond(tf.less(columns, len(schema)),
                        lambda: tf.string_join([tf.constant(','), instances]),
                        lambda: instances)

  # Convert the schema into a set of tensor defaults, to be used for parsing csv data.
  defaults = []
  for field in schema:
    if field.length != 1:
      # TODO: Support variable length, and list columns in csv.
      raise ValueError('Unsupported schema field "%s". Length must be 1.' % field.name)

    if field.type == SchemaFieldType.integer:
      field_default = tf.constant(0, dtype=tf.int64)
    elif field.type == SchemaFieldType.real:
      field_default = tf.constant(0.0, dtype=tf.float32)
    else:
      # discrete, text, binary
      field_default = tf.constant('', dtype=tf.string)
    defaults.append([field_default])

  values = tf.decode_csv(instances, defaults, name='csv')

  parsed_instances = {}
  for field, value in zip(schema, values):
    # The parsed values are scalars, so each tensor is of shape (None,); turn them into tensors
    # of shape (None, 1).
    parsed_instances[field.name] = tf.expand_dims(value, axis=1, name=field.name)

  return parsed_instances
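
The same logic can be traced without building a TensorFlow graph. The block below is a minimal plain-Python sketch, not the tensorfx implementation: `Field`, `CONVERTERS`, `parse_csv_rows` and the string type names `'integer'` and `'real'` are hypothetical stand-ins for the schema objects and `SchemaFieldType` values used above. It shows the three steps of `parse_csv`: detecting a missing target column from the first instance, falling back to per-type defaults for empty fields, and wrapping each scalar so every column has shape (n, 1).

```python
import csv
from collections import namedtuple

# Hypothetical stand-in for a schema field; only the attributes that parse_csv
# reads (name, type, length) are modeled.
Field = namedtuple('Field', ['name', 'type', 'length'])

# Converter and default per type, mirroring the int64 / float32 defaults above.
# Any other type (discrete, text, binary) is treated as a string with default ''.
CONVERTERS = {'integer': (int, 0), 'real': (float, 0.0)}


def parse_csv_rows(schema, instances, prediction):
  """Parses CSV lines into a dict mapping field name to a list of 1-element rows."""
  for field in schema:
    if field.length != 1:
      raise ValueError('Unsupported schema field "%s". Length must be 1.' % field.name)

  if prediction and instances:
    # Count the columns of the first instance only, as the graph version does.
    columns = len(next(csv.reader([instances[0]])))
    if columns < len(schema):
      # The target (assumed to be the first column) is missing, so prepend an
      # empty field to every instance.
      instances = [',' + line for line in instances]

  parsed = {field.name: [] for field in schema}
  for row in csv.reader(instances):
    if len(row) != len(schema):
      raise ValueError('Expected %d columns, found %d.' % (len(schema), len(row)))
    for field, text in zip(schema, row):
      convert, default = CONVERTERS.get(field.type, (str, ''))
      # An empty field falls back to the default for its type.
      value = convert(text) if text != '' else default
      # Wrap the scalar so that each column has shape (n, 1) instead of (n,).
      parsed[field.name].append([value])
  return parsed


schema = [Field('label', 'integer', 1), Field('x', 'real', 1), Field('color', 'discrete', 1)]
# Prediction input without the target column; the second row also has an empty 'x'.
result = parse_csv_rows(schema, ['1.5,red', ',blue'], prediction=True)
```

Here `result['label']` holds the integer default for both rows because the target was absent, and `result['x']` holds the real default for the second row. Unlike the `tf.string_split` based count, `csv.reader` keeps empty fields, so a first instance with empty values is still counted at its full width.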