dfutil.py source code


Project: TensorFlowOnSpark    Author: yahoo
def loadTFRecords(sc, input_dir, binary_features=[]):
  """Load TFRecords from disk into a Spark DataFrame.

  This will attempt to automatically convert the tf.train.Example features into Spark DataFrame columns of equivalent types.

  Note: TensorFlow represents both strings and binary data as tf.train.BytesList, and we need to
  disambiguate these for the corresponding Spark DataFrame types (StringType and BinaryType), so we require a "hint"
  from the caller in the ``binary_features`` argument.

  Args:
    :sc: SparkContext
    :input_dir: location of TFRecords on disk.
    :binary_features: a list of tf.train.Example features which are expected to be binary/bytearrays.

  Returns:
    A Spark DataFrame mirroring the tf.train.Example schema.
  """
  import tensorflow as tf

  tfr_rdd = sc.newAPIHadoopFile(input_dir, "org.tensorflow.hadoop.io.TFRecordFileInputFormat",
                                keyClass="org.apache.hadoop.io.BytesWritable",
                                valueClass="org.apache.hadoop.io.NullWritable")

  # infer Spark SQL types from tf.Example
  record = tfr_rdd.take(1)[0]
  example = tf.train.Example()
  example.ParseFromString(bytes(record[0]))
  schema = infer_schema(example, binary_features)

  # convert serialized protobuf to tf.Example to Row
  example_rdd = tfr_rdd.mapPartitions(lambda x: fromTFExample(x, binary_features))

  # create a Spark DataFrame from RDD[Row]
  df = example_rdd.toDF(schema)

  # remember the source path of this DataFrame (loadedDF is a module-level dict defined elsewhere in dfutil.py)
  loadedDF[df] = input_dir
  return df
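
A minimal usage sketch, assuming loadTFRecords is imported from TensorFlowOnSpark's dfutil module and that the tensorflow-hadoop connector jar (which provides org.tensorflow.hadoop.io.TFRecordFileInputFormat) is on the Spark classpath; the HDFS path and the "image/encoded" feature name are hypothetical:

from pyspark import SparkContext
from tensorflowonspark import dfutil  # assumed import path

sc = SparkContext(appName="tfrecords_to_dataframe")

# "image/encoded" holds raw JPEG bytes, so hint it as binary to map it to BinaryType
# instead of StringType; all other features are inferred from the first record.
df = dfutil.loadTFRecords(sc, "hdfs:///data/train_tfrecords",
                          binary_features=["image/encoded"])
df.printSchema()

Without the binary_features hint, any tf.train.BytesList feature would be treated as a string column, which can corrupt non-UTF-8 payloads such as encoded images.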