Count the number of non-NaN entries in each column of a Spark DataFrame with PySpark
I have a very large dataset loaded in Hive. It consists of approximately 1.9 million rows and 1450 columns. I need to determine the "coverage" of each column, that is, the fraction of rows that have a non-NaN value for each column.
Here is my code:
from pyspark import SparkContext
from pyspark.sql import HiveContext
import string as string
sc = SparkContext(appName="compute_coverages") ## Create the context
sqlContext = HiveContext(sc)
df = sqlContext.sql("select * from data_table")
nrows_tot = df.count()
covgs = sc.parallelize(df.columns)
        .map(lambda x: str(x))
        .map(lambda x: (x, float(df.select(x).dropna().count()) / float(nrows_tot) * 100.))
When I try this out in the pyspark shell and then run covgs.take(10), it returns a rather large error stack. It says there is a problem with save in the file /usr/lib64/python2.6/pickle.py. This is the final part of the error:
py4j.protocol.Py4JError: An error occurred while calling o37.__getnewargs__. Trace:
py4j.Py4JException: Method __getnewargs__([]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:333)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:342)
at py4j.Gateway.invoke(Gateway.java:252)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:207)
at java.lang.Thread.run(Thread.java:745)
If there is a better way to accomplish this than the approach I am trying, I am open to suggestions. However, I cannot use pandas, as it is not currently available on the cluster I am working on and I do not have the rights to install it.
-
Let's start with dummy data:
from pyspark.sql import Row

row = Row("v", "x", "y", "z")
df = sc.parallelize([
    row(0.0, 1, 2, 3.0),
    row(None, 3, 4, 5.0),
    row(None, None, 6, 7.0),
    row(float("Nan"), 8, 9, float("NaN"))
]).toDF()

## +----+----+---+---+
## |   v|   x|  y|  z|
## +----+----+---+---+
## | 0.0|   1|  2|3.0|
## |null|   3|  4|5.0|
## |null|null|  6|7.0|
## | NaN|   8|  9|NaN|
## +----+----+---+---+
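As an aside (my addition, not part of the original answer): NULL and NaN are distinct values in Spark, which is why the two cases are handled separately below. A minimal check on the dummy column v:

from pyspark.sql.functions import col, isnan

df.select(
    col("v").isNull().alias("v_is_null"),  # True only for None / NULL
    isnan("v").alias("v_is_nan")           # True only for NaN
).show()

## v = 0.0  -> (false, false)
## v = null -> (true,  false)
## v = NaN  -> (false, true)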
All you need is a simple aggregation:
from pyspark.sql.functions import col, count, isnan, lit, sum

def count_not_null(c, nan_as_null=False):
    """Use conversion between boolean and integer
    - False -> 0
    - True ->  1
    """
    pred = col(c).isNotNull() & (~isnan(c) if nan_as_null else lit(True))
    return sum(pred.cast("integer")).alias(c)

df.agg(*[count_not_null(c) for c in df.columns]).show()

## +---+---+---+---+
## |  v|  x|  y|  z|
## +---+---+---+---+
## |  2|  3|  4|  4|
## +---+---+---+---+
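To see the boolean-to-integer trick from count_not_null in isolation (a small sketch of my own, not part of the original answer), the predicate can be inspected per row before it is summed:

from pyspark.sql.functions import col

df.select(
    col("v"),
    col("v").isNotNull().cast("integer").alias("v_not_null")  # 1 if non-NULL, else 0
).show()

## +----+----------+
## |   v|v_not_null|
## +----+----------+
## | 0.0|         1|
## |null|         0|
## |null|         0|
## | NaN|         1|
## +----+----------+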
Or, if you want to treat NaN as NULL:

df.agg(*[count_not_null(c, True) for c in df.columns]).show()

## +---+---+---+---+
## |  v|  x|  y|  z|
## +---+---+---+---+
## |  1|  3|  4|  3|
## +---+---+---+---+
You can also leverage SQL NULL semantics to achieve the same result without creating a custom function:

df.agg(*[
    count(c).alias(c)  # vertical (column-wise) operations in SQL ignore NULLs
    for c in df.columns
]).show()

## +---+---+---+---+
## |  v|  x|  y|  z|
## +---+---+---+---+
## |  2|  3|  4|  4|
## +---+---+---+---+
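The same NULL-ignoring behaviour is also available through plain SQL expression strings; a sketch of my own (not from the original answer), assuming the column names need no quoting:

df.selectExpr(*["count({0}) AS {0}".format(c) for c in df.columns]).show()

## +---+---+---+---+
## |  v|  x|  y|  z|
## +---+---+---+---+
## |  2|  3|  4|  4|
## +---+---+---+---+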
But this won't work with NaNs. If you prefer fractions:
exprs = [(count_not_null(c) / count("*")).alias(c) for c in df.columns]

df.agg(*exprs).show()

## +---+----+---+---+
## |  v|   x|  y|  z|
## +---+----+---+---+
## |0.5|0.75|1.0|1.0|
## +---+----+---+---+
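Since the question ultimately wants a per-column coverage figure for roughly 1450 columns, the single aggregated row can be pulled back to the driver and turned into a plain Python dict. A sketch of my own (the coverages name is mine, not part of the original answer):

row = df.agg(*exprs).collect()[0].asDict()
coverages = dict((c, 100. * row[c]) for c in df.columns)  # percentage of non-NULL rows per column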
Or, equivalently, using only the built-in count:
# COUNT(*) is equivalent to COUNT(1) so NULLs won't be an issue
df.select(*[(count(c) / count("*")).alias(c) for c in df.columns]).show()

## +---+----+---+---+
## |  v|   x|  y|  z|
## +---+----+---+---+
## |0.5|0.75|1.0|1.0|
## +---+----+---+---+
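Another way to get the fractions directly (my own variant, not from the original answer) is to average the 0/1 predicate with avg, which divides by the row count for you:

from pyspark.sql.functions import avg, col

df.agg(*[
    avg(col(c).isNotNull().cast("double")).alias(c)  # mean of 0/1 flags = fraction of non-NULL rows
    for c in df.columns
]).show()

## +---+----+---+---+
## |  v|   x|  y|  z|
## +---+----+---+---+
## |0.5|0.75|1.0|1.0|
## +---+----+---+---+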
The Scala equivalent:
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, isnan, lit, not, sum}

type JDouble = java.lang.Double

val df = Seq[(JDouble, JDouble, JDouble, JDouble)](
  (0.0, 1.0, 2.0, 3.0),
  (null, 3.0, 4.0, 5.0),
  (null, null, 6.0, 7.0),
  (java.lang.Double.NaN, 8.0, 9.0, java.lang.Double.NaN)
).toDF()

def count_not_null(c: Column, nanAsNull: Boolean = false) = {
  val pred = c.isNotNull and (if (nanAsNull) not(isnan(c)) else lit(true))
  sum(pred.cast("integer"))
}

df.select(df.columns map (c => count_not_null(col(c)).alias(c)): _*).show

// +---+---+---+---+
// | _1| _2| _3| _4|
// +---+---+---+---+
// |  2|  3|  4|  4|
// +---+---+---+---+

df.select(df.columns map (c => count_not_null(col(c), true).alias(c)): _*).show

// +---+---+---+---+
// | _1| _2| _3| _4|
// +---+---+---+---+
// |  1|  3|  4|  3|
// +---+---+---+---+