test_parallel_utils.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:cykdtree 作者: cykdtree 项目源码 文件源码
def test_calc_rounds():
    """Compare ``parallel_utils.py_calc_rounds`` with a reference answer.

    The expected round count is ``ceil(log2(size)) + 1`` for the world
    communicator size.  The expected source round is the number of
    splits (using ``parallel_utils.py_calc_split_rank``) needed before
    this process's rank, relative to its shrinking sub-group, becomes 0.
    """
    world = MPI.COMM_WORLD
    rank = world.Get_rank()
    size = world.Get_size()

    # Reference value for the total number of rounds.
    ans_nrounds = 1 + int(np.ceil(np.log2(size)))

    # Reference value for the source round: repeatedly split the current
    # group and follow whichever half contains this rank until the rank
    # relative to its group is 0.
    ans_src_round = 0
    sub_rank, sub_size = rank, size
    while sub_rank != 0:
        pivot = parallel_utils.py_calc_split_rank(sub_size)
        if sub_rank >= pivot:
            # Upper half: re-index rank and size relative to the pivot.
            sub_size -= pivot
            sub_rank -= pivot
        else:
            # Lower half: rank is unchanged, the group shrinks to the pivot.
            sub_size = pivot
        ans_src_round += 1

    # Values under test.
    nrounds, src_round = parallel_utils.py_calc_rounds()
    assert_equal(nrounds, ans_nrounds)
    assert_equal(src_round, ans_src_round)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号