在pyspark中检索每个DataFrame组中的前n个
发布于 2021-01-29 19:35:20
pyspark中有一个DataFrame,其数据如下:
user_id object_id score
user_1 object_1 3
user_1 object_1 1
user_1 object_2 2
user_2 object_1 5
user_2 object_2 2
user_2 object_2 6
我期望在每个组中返回2条记录,每条记录具有相同的user_id,它们需要具有最高的得分。因此,结果应如下所示:
user_id object_id score
user_1 object_1 3
user_1 object_2 2
user_2 object_2 6
user_2 object_1 5
我真的是pyspark的新手,有人可以给我一个代码段或门户网站有关此问题的相关文档吗?万分感谢!
关注者
0
被浏览
56
1 个回答
-
我相信您需要使用窗口函数基于
user_id
和来获得每一行的排名score
,然后过滤结果以仅保留前两个值。from pyspark.sql.window import Window from pyspark.sql.functions import rank, col window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc()) df.select('*', rank().over(window).alias('rank')) .filter(col('rank') <= 2) .show() #+-------+---------+-----+----+ #|user_id|object_id|score|rank| #+-------+---------+-----+----+ #| user_1| object_1| 3| 1| #| user_1| object_2| 2| 2| #| user_2| object_2| 6| 1| #| user_2| object_1| 5| 2| #+-------+---------+-----+----+
通常,官方编程指南是开始学习Spark的好地方。
数据
rdd = sc.parallelize([("user_1", "object_1", 3), ("user_1", "object_2", 2), ("user_2", "object_1", 5), ("user_2", "object_2", 2), ("user_2", "object_2", 6)]) df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])