在pyspark中检索每个DataFrame组中的前n个

发布于 2021-01-29 19:35:20

pyspark中有一个DataFrame,其数据如下:

user_id object_id score
user_1  object_1  3
user_1  object_1  1
user_1  object_2  2
user_2  object_1  5
user_2  object_2  2
user_2  object_2  6

我期望在每个组中返回2条记录,每条记录具有相同的user_id,它们需要具有最高的得分。因此,结果应如下所示:

user_id object_id score
user_1  object_1  3
user_1  object_2  2
user_2  object_2  6
user_2  object_1  5

我真的是pyspark的新手,有人可以给我一个代码段或门户网站有关此问题的相关文档吗?万分感谢!

关注者
0
被浏览
56
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    我相信您需要使用窗口函数基于user_id和来获得每一行的排名score,然后过滤结果以仅保留前两个值。

    from pyspark.sql.window import Window
    from pyspark.sql.functions import rank, col
    
    window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())
    
    df.select('*', rank().over(window).alias('rank')) 
      .filter(col('rank') <= 2) 
      .show() 
    #+-------+---------+-----+----+
    #|user_id|object_id|score|rank|
    #+-------+---------+-----+----+
    #| user_1| object_1|    3|   1|
    #| user_1| object_2|    2|   2|
    #| user_2| object_2|    6|   1|
    #| user_2| object_1|    5|   2|
    #+-------+---------+-----+----+
    

    通常,官方编程指南是开始学习Spark的好地方。

    数据

    rdd = sc.parallelize([("user_1",  "object_1",  3), 
                          ("user_1",  "object_2",  2), 
                          ("user_2",  "object_1",  5), 
                          ("user_2",  "object_2",  2), 
                          ("user_2",  "object_2",  6)])
    df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
    


知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看