def test_calc_rounds():
    """Check ``parallel_utils.py_calc_rounds`` against a pure-Python model.

    The expected round count is ``ceil(log2(size)) + 1``, where ``size`` is
    the size of ``MPI.COMM_WORLD``.

    The expected source round for this process works as follows. Start
    with this process's rank and the full communicator size. Split the
    group at ``parallel_utils.py_calc_split_rank(group_size)`` and follow
    the part that contains the rank. Repeat until the rank, renumbered
    within its sub-group, is 0. The number of splits taken is the expected
    source round.
    """
    world = MPI.COMM_WORLD
    my_rank = world.Get_rank()
    world_size = world.Get_size()

    # Expected values, computed independently of the routine under test.
    expected_nrounds = int(np.ceil(np.log2(world_size))) + 1

    expected_src_round = 0
    local_rank, group_size = my_rank, world_size
    while local_rank != 0:
        split = parallel_utils.py_calc_split_rank(group_size)
        if local_rank >= split:
            # The rank is in the upper part. Shrink the group to that part
            # and renumber the rank relative to its start.
            group_size -= split
            local_rank -= split
        else:
            # The rank is in the lower part. Its local index is unchanged.
            group_size = split
        expected_src_round += 1

    # Compare against the implementation.
    nrounds, src_round = parallel_utils.py_calc_rounds()
    assert_equal(nrounds, expected_nrounds)
    assert_equal(src_round, expected_src_round)
评论列表
文章目录