def test_parallel_pivot_value(ndim=2, npts=50):
    """Check that the parallel pivot value splits the points in a balanced way.

    Rank 0 generates ``npts`` random points in ``ndim`` dimensions. It
    broadcasts the full set to every rank so each rank can verify the result
    against all points. The points are then distributed across ranks with
    ``parallel_utils.py_parallel_distribute``. A pivot value along the last
    dimension is computed with ``parallel_utils.py_parallel_pivot_value``.

    The test asserts that at most ``7*npts/10 + 6`` points lie strictly
    below the pivot, and at most that many lie strictly above it.

    Args:
        ndim: Number of dimensions of each point.
        npts: Total number of points generated on rank 0.
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if rank == 0:
        pts = np.random.rand(npts, ndim).astype('float64')
    else:
        pts = None
    # Every rank receives a copy of all points so the split can be checked
    # locally against the global data set.
    total_pts = comm.bcast(pts, root=0)
    # NOTE(review): ``pts`` is None on non-root ranks here; presumably
    # py_parallel_distribute scatters from rank 0 -- confirm.
    local_pts, _local_idx = parallel_utils.py_parallel_distribute(pts)
    pivot_dim = ndim - 1
    piv = parallel_utils.py_parallel_pivot_value(local_pts, pivot_dim)
    # Presumably the median-of-medians (groups of 5) guarantee: at most
    # 7n/10 + 6 elements fall strictly on either side of the pivot.
    nmax = 7 * npts / 10 + 6
    coords = total_pts[:, pivot_dim]
    n_below = np.sum(coords < piv)
    n_above = np.sum(coords > piv)
    assert n_below <= nmax, (n_below, nmax)
    assert n_above <= nmax, (n_above, nmax)
    # A direct comparison against the serial ``utils.py_pivot`` result is
    # intentionally not made. The two are not equivalent because each
    # process does not hold a multiple of 5 points.
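# "评论列表" ("comment list") and "文章目录" ("table of contents") appear to be
# navigation labels left over from the web page this file was copied from.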