dfutil.py source code

python

Project: TensorFlowOnSpark  Author: yahoo
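from pyspark.sql.types import (ArrayType, BinaryType, DoubleType, LongType,
                               StringType, StructField, StructType)


def infer_schema(example, binary_features=[]):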
  """Given a tf.train.Example, infer the Spark DataFrame schema (StructFields).

  Note: TensorFlow represents both strings and binary data as tf.train.BytesList. Spark DataFrames
  distinguish the two (StringType and BinaryType), so the caller must provide a "hint" in the
  ``binary_features`` argument naming the features to treat as binary.

  Args:
    :example: a tf.train.Example
    :binary_features: a list of tf.train.Example features which are expected to be binary/bytearrays.

  Returns:
    A DataFrame StructType schema
  """
  def _infer_sql_type(k, v):
    # special handling for binary features: always BinaryType, regardless of how many values are present
    if k in binary_features:
      return BinaryType()

    if v.int64_list.value:
      result = v.int64_list.value
      sql_type = LongType()
    elif v.float_list.value:
      result = v.float_list.value
      sql_type = DoubleType()
    else:
      result = v.bytes_list.value
      sql_type = StringType()

    if len(result) > 1:             # represent multi-item tensors as Spark SQL ArrayType() of base types
      return ArrayType(sql_type)
    else:                           # represent everything else as base types (and empty tensors as StringType())
      return sql_type

  return StructType([StructField(k, _infer_sql_type(k, v), True) for k, v in sorted(example.features.feature.items())])
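
The inference rules above can be tried out without TensorFlow or Spark installed. The following is a minimal sketch that mirrors the same decision order using plain Python lists in place of a tf.train.Example and string labels in place of Spark SQL types. The function name `infer_type_names`, the type labels, and the sample feature names are all hypothetical and are not part of TensorFlowOnSpark. It also dispatches on the Python type of the values, whereas the original checks which of `int64_list`, `float_list`, or `bytes_list` is non-empty.

```python
def infer_type_names(features, binary_features=()):
    """Map {feature name: list of values} to sorted (name, type label) pairs.

    Hypothetical stand-in for infer_schema: plain lists replace tf.train.Feature
    and string labels replace Spark SQL types.
    """
    def _infer(name, values):
        # The caller's hint wins, regardless of how many values there are.
        if name in binary_features:
            return "binary"

        if values and all(isinstance(x, int) for x in values):
            base = "long"
        elif values and all(isinstance(x, float) for x in values):
            base = "double"
        else:
            # Strings, bytes, and empty features all fall through to string.
            base = "string"

        # More than one value becomes an array of the base type.
        return "array<%s>" % base if len(values) > 1 else base

    # Sort by feature name so the field order is deterministic.
    return [(name, _infer(name, values)) for name, values in sorted(features.items())]


features = {
    "label": [1],
    "pixels": [0.0, 0.5, 1.0],
    "name": [b"cat"],
    "image": [b"\x89PNG", b"\x00\x01"],
    "empty": [],
}
schema = infer_type_names(features, binary_features=["image"])
for name, type_name in schema:
    print(name, type_name)
```

In this sketch `image` holds two values but is still reported as a plain `binary` field, because the hint check returns before the array check, just as in `_infer_sql_type`. The `empty` feature falls through to `string`, matching the comment in the original about empty tensors.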