def loadTFRecords(sc, input_dir, binary_features=[]):
"""Load TFRecords from disk into a Spark DataFrame.
This will attempt to automatically convert the tf.train.Example features into Spark DataFrame columns of equivalent types.
Note: TensorFlow represents both strings and binary types as tf.train.BytesList, and we need to
disambiguate these types for Spark DataFrames DTypes (StringType and BinaryType), so we require a "hint"
from the caller in the ``binary_features`` argument.
Args:
:sc: SparkContext
:input_dir: location of TFRecords on disk.
:binary_features: a list of tf.train.Example features which are expected to be binary/bytearrays.
Returns:
A Spark DataFrame mirroring the tf.train.Example schema.
"""
  import tensorflow as tf
  tfr_rdd = sc.newAPIHadoopFile(input_dir, "org.tensorflow.hadoop.io.TFRecordFileInputFormat",
                                keyClass="org.apache.hadoop.io.BytesWritable",
                                valueClass="org.apache.hadoop.io.NullWritable")
  # infer Spark SQL types from tf.train.Example
  record = tfr_rdd.take(1)[0]
  example = tf.train.Example()
  example.ParseFromString(bytes(record[0]))
  schema = infer_schema(example, binary_features)
  # convert serialized protobuf to tf.train.Example to Row
  example_rdd = tfr_rdd.mapPartitions(lambda x: fromTFExample(x, binary_features))
  # create a Spark DataFrame from RDD[Row]
  df = example_rdd.toDF(schema)
  # remember where this DataFrame was loaded from (loadedDF is a module-level dict mapping DataFrame -> input path)
  loadedDF[df] = input_dir
  return df
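

# A minimal usage sketch, not part of the original source. It assumes the rest of
# the module (infer_schema, fromTFExample, loadedDF) is available, that Spark was
# started with the tensorflow-hadoop connector jar (which provides
# org.tensorflow.hadoop.io.TFRecordFileInputFormat) on its classpath, and that the
# path and the "image" feature name below are hypothetical placeholders for
# TFRecords containing tf.train.Example protos with a raw-bytes feature.
if __name__ == "__main__":
  from pyspark.sql import SparkSession

  spark = SparkSession.builder.appName("loadTFRecords-example").getOrCreate()

  # "image" is passed as a binary_features hint so its BytesList is mapped to
  # BinaryType instead of StringType.
  df = loadTFRecords(spark.sparkContext, "/path/to/tfrecords", binary_features=["image"])
  df.printSchema()  # columns mirror the tf.train.Example feature names/types
  df.show(5)

  spark.stop()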