from pyspark.sql.types import ArrayType, BinaryType, DoubleType, LongType, StringType, StructField, StructType


def infer_schema(example, binary_features=[]):
"""Given a tf.train.Example, infer the Spark DataFrame schema (StructFields).
Note: TensorFlow represents both strings and binary types as tf.train.BytesList, and we need to
disambiguate these types for Spark DataFrames DTypes (StringType and BinaryType), so we require a "hint"
from the caller in the ``binary_features`` argument.
Args:
:example: a tf.train.Example
:binary_features: a list of tf.train.Example features which are expected to be binary/bytearrays.
Returns:
A DataFrame StructType schema
"""
    def _infer_sql_type(k, v):
        # special handling for binary features
        if k in binary_features:
            return BinaryType()

        # pick the Spark SQL base type from whichever tf.train.Feature list is populated
        if v.int64_list.value:
            result = v.int64_list.value
            sql_type = LongType()
        elif v.float_list.value:
            result = v.float_list.value
            sql_type = DoubleType()
        else:
            result = v.bytes_list.value
            sql_type = StringType()

        if len(result) > 1:    # represent multi-item tensors as Spark SQL ArrayType() of base types
            return ArrayType(sql_type)
        else:                  # represent everything else as base types (and empty tensors as StringType())
            return sql_type

    return StructType([StructField(k, _infer_sql_type(k, v), True) for k, v in sorted(example.features.feature.items())])
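For context, a minimal usage sketch follows; the feature names "label" and "image_raw" are purely illustrative assumptions, not part of the function above:

import tensorflow as tf

# build a tiny tf.train.Example with one int64 feature and one binary feature
example = tf.train.Example(features=tf.train.Features(feature={
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
    "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"\x00\x01\x02"])),
}))

# the "hint" tells infer_schema which BytesList features are binary rather than strings
schema = infer_schema(example, binary_features=["image_raw"])
print(schema)
# image_raw -> BinaryType, label -> LongType (fields are sorted by name and nullable)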