使用连接时,Spark迭代时间呈指数增长

发布于 2021-01-29 15:59:09

我是Spark的新手,我正在尝试实现一些迭代算法,以马尔可夫模型表示的质心进行聚类(期望最大化)。因此,我需要进行迭代和联接。

我遇到的一个问题是每次迭代的时间呈指数增长。
经过一些实验,我发现在进行迭代时,需要保留将在下一次迭代中重用的RDD,否则每次迭代火花都会创建执行计划,该计划将从开始就重新计算RDD,从而增加了计算时间。

init = sc.parallelize(xrange(10000000), 3)
init.cache()

for i in range(6):
    print i
    start = datetime.datetime.now()

    init2 = init.map(lambda n: (n, n*3))        
    init = init2.map(lambda n: n[0])
#     init.cache()

    print init.count()    
    print str(datetime.datetime.now() - start)

结果是:

0
10000000
0:00:04.283652
1
10000000
0:00:05.998830
2
10000000
0:00:08.771984
3
10000000
0:00:11.399581
4
10000000
0:00:14.206069
5
10000000
0:00:16.856993

因此,添加cache()可以帮助并使迭代时间保持不变。

init = sc.parallelize(xrange(10000000), 3)
init.cache()

for i in range(6):
    print i
    start = datetime.datetime.now()

    init2 = init.map(lambda n: (n, n*3))        
    init = init2.map(lambda n: n[0])
    init.cache()

    print init.count()    
    print str(datetime.datetime.now() - start)
0
10000000
0:00:04.966835
1
10000000
0:00:04.609885
2
10000000
0:00:04.324358
3
10000000
0:00:04.248709
4
10000000
0:00:04.218724
5
10000000
0:00:04.223368

但是在迭代中加入Join时,问题又回来了。这是我演示问题的一些简单代码。即使在每个RDD转换上进行缓存也不能解决问题:

init = sc.parallelize(xrange(10000), 3)
init.cache()

for i in range(6):
    print i
    start = datetime.datetime.now()

    init2 = init.map(lambda n: (n, n*3))
    init2.cache()

    init3 = init.map(lambda n: (n, n*2))
    init3.cache()

    init4 = init2.join(init3)
    init4.count()
    init4.cache()

    init = init4.map(lambda n: n[0])
    init.cache()

    print init.count()    
    print str(datetime.datetime.now() - start)

这是输出。如您所见,迭代时间呈指数增长:(

0
10000
0:00:00.674115
1
10000
0:00:00.833377
2
10000
0:00:01.525314
3
10000
0:00:04.194715
4
10000
0:00:08.139040
5
10000
0:00:17.852815

我将非常感谢您的帮助:)

关注者
0
被浏览
43
1 个回答
  • 面试哥
    面试哥 2021-01-29
    为面试而生,有面试问题,就找面试哥。

    总结

    一般来说,迭代算法,尤其是具有自联接或自联合的迭代算法,需要对以下内容进行控制:

    这里描述的问题是由于缺少前者造成的。在每次迭代中,分区的数量随着自连接而增加,从而导致指数模式。为了解决这个问题,您必须在每次迭代中控制分区数(请参阅下文),或使用诸如这样的全局工具spark.default.parallelism(请参阅Travis提供的答案)。通常,第一种方法通常提供更多的控制,并且不影响代码的其他部分。

    原始答案

    据我所知,这里有两个交错的问题-分区数量的增加和联接期间的开销。两者都可以轻松处理,因此请逐步进行。

    首先让我们创建一个帮助器来收集统计信息:

    import datetime
    
    def get_stats(i, init, init2, init3, init4,
           start, end, desc, cache, part, hashp):
        return {
            "i": i,
            "init": init.getNumPartitions(),
            "init1": init2.getNumPartitions(),
            "init2": init3.getNumPartitions(),
            "init4": init4.getNumPartitions(),
            "time": str(end - start),
            "timen": (end - start).seconds + (end - start).microseconds * 10 **-6,
            "desc": desc,
            "cache": cache,
            "part": part,
            "hashp": hashp
        }
    

    另一个帮助处理缓存/分区的助手

    def procRDD(rdd, cache=True, part=False, hashp=False, npart=16):
        rdd = rdd if not part else rdd.repartition(npart)
        rdd = rdd if not hashp else rdd.partitionBy(npart)
        return rdd if not cache else rdd.cache()
    

    提取管道逻辑:

    def run(init, description, cache=True, part=False, hashp=False, 
        npart=16, n=6):
        times = []
    
        for i in range(n):
            start = datetime.datetime.now()
    
            init2 = procRDD(
                    init.map(lambda n: (n, n*3)),
                    cache, part, hashp, npart)
            init3 = procRDD(
                    init.map(lambda n: (n, n*2)),
                    cache, part, hashp, npart)
    
    
            # If part set to True limit number of the output partitions
            init4 = init2.join(init3, npart) if part else init2.join(init3) 
            init = init4.map(lambda n: n[0])
    
            if cache:
                init4.cache()
                init.cache()
    
            init.count() # Force computations to get time
            end = datetime.datetime.now()
    
            times.append(get_stats(
                i, init, init2, init3, init4,
                start, end, description,
                cache, part, hashp
            ))
    
        return times
    

    并创建初始数据:

    ncores = 8
    init = sc.parallelize(xrange(10000), ncores * 2).cache()
    

    numPartitions单独进行联接操作(如果未提供参数),则根据输入RDD的分区数来调整输出中的分区数。这意味着每次迭代的分区数量都在增加。如果分区的数目很大,那么丑陋的事情就会变得很糟。您可以通过为numPartitions每次迭代提供连接或重新分区RDD的参数来处理这些问题。

    timesCachePart = sqlContext.createDataFrame(
            run(init, "cache + partition", True, True, False, ncores * 2))
    timesCachePart.select("i", "init1", "init2", "init4", "time", "desc").show()
    
    +-+-----+-----+-----+--------------+-----------------+
    |i|init1|init2|init4|          time|             desc|
    +-+-----+-----+-----+--------------+-----------------+
    |0|   16|   16|   16|0:00:01.145625|cache + partition|
    |1|   16|   16|   16|0:00:01.090468|cache + partition|
    |2|   16|   16|   16|0:00:01.059316|cache + partition|
    |3|   16|   16|   16|0:00:01.029544|cache + partition|
    |4|   16|   16|   16|0:00:01.033493|cache + partition|
    |5|   16|   16|   16|0:00:01.007598|cache + partition|
    +-+-----+-----+-----+--------------+-----------------+
    

    如您所见,当我们重新分区时,执行时间或多或少是恒定的。第二个问题是上述数据是随机分区的。为了确保联接性能,我们希望在单个分区上具有相同的键。为此,我们可以使用哈希分区程序:

    timesCacheHashPart = sqlContext.createDataFrame(
        run(init, "cache + hashpart", True, True, True, ncores * 2))
    timesCacheHashPart.select("i", "init1", "init2", "init4", "time", "desc").show()
    
    +-+-----+-----+-----+--------------+----------------+
    |i|init1|init2|init4|          time|            desc|
    +-+-----+-----+-----+--------------+----------------+
    |0|   16|   16|   16|0:00:00.946379|cache + hashpart|
    |1|   16|   16|   16|0:00:00.966519|cache + hashpart|
    |2|   16|   16|   16|0:00:00.945501|cache + hashpart|
    |3|   16|   16|   16|0:00:00.986777|cache + hashpart|
    |4|   16|   16|   16|0:00:00.960989|cache + hashpart|
    |5|   16|   16|   16|0:00:01.026648|cache + hashpart|
    +-+-----+-----+-----+--------------+----------------+
    

    执行时间与以前一样是恒定的,并且对基本分区的改进很小。

    现在,仅将缓存用作参考:

    timesCacheOnly = sqlContext.createDataFrame(
        run(init, "cache-only", True, False, False, ncores * 2))
    timesCacheOnly.select("i", "init1", "init2", "init4", "time", "desc").show()
    
    
    +-+-----+-----+-----+--------------+----------+
    |i|init1|init2|init4|          time|      desc|
    +-+-----+-----+-----+--------------+----------+
    |0|   16|   16|   32|0:00:00.992865|cache-only|
    |1|   32|   32|   64|0:00:01.766940|cache-only|
    |2|   64|   64|  128|0:00:03.675924|cache-only|
    |3|  128|  128|  256|0:00:06.477492|cache-only|
    |4|  256|  256|  512|0:00:11.929242|cache-only|
    |5|  512|  512| 1024|0:00:23.284508|cache-only|
    +-+-----+-----+-----+--------------+----------+
    

    如您所见,纯缓存版本的分区数量(init2,init3,init4)在每次迭代中都会加倍,执行时间与分区数量成正比。

    最后,如果使用哈希分区程序,我们可以检查是否可以通过大量分区来提高性能:

    timesCacheHashPart512 = sqlContext.createDataFrame(
        run(init, "cache + hashpart 512", True, True, True, 512))
    timesCacheHashPart512.select(
        "i", "init1", "init2", "init4", "time", "desc").show()
    +-+-----+-----+-----+--------------+--------------------+
    |i|init1|init2|init4|          time|                desc|
    +-+-----+-----+-----+--------------+--------------------+
    |0|  512|  512|  512|0:00:14.492690|cache + hashpart 512|
    |1|  512|  512|  512|0:00:20.215408|cache + hashpart 512|
    |2|  512|  512|  512|0:00:20.408070|cache + hashpart 512|
    |3|  512|  512|  512|0:00:20.390267|cache + hashpart 512|
    |4|  512|  512|  512|0:00:20.362354|cache + hashpart 512|
    |5|  512|  512|  512|0:00:19.878525|cache + hashpart 512|
    +-+-----+-----+-----+--------------+--------------------+
    

    改进并不是那么令人印象深刻,但是如果您的集群很小且数据很多,那么仍然值得尝试。

    我想这里带走的消息是分区问题。在某些情况下(mllibsql)为您处理它,但是如果您使用低级操作,这是您的责任。



知识点
面圈网VIP题库

面圈网VIP题库全新上线,海量真题题库资源。 90大类考试,超10万份考试真题开放下载啦

去下载看看