test_parallel_utils.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:cykdtree 作者: cykdtree 项目源码 文件源码
def test_parallel_pivot_value(ndim=2, npts=50):
    """Check that the parallel pivot value splits the global points sensibly.

    Rank 0 draws ``npts`` random points in ``ndim`` dimensions and
    broadcasts them so every rank can check the result against the full
    point set. The same array is handed to ``py_parallel_distribute``
    (non-root ranks pass ``None``). The pivot value is then computed from
    the local share along the last dimension.

    The test asserts that, along the pivot dimension, at most
    ``7*npts/10 + 6`` points lie strictly below the pivot and at most that
    many lie strictly above it.
    NOTE(review): this looks like the median-of-medians worst-case bound;
    confirm against the pivot implementation.

    Parameters
    ----------
    ndim : int
        Number of dimensions of the random points.
    npts : int
        Total number of points generated on rank 0.
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if rank == 0:
        pts = np.random.rand(npts, ndim).astype('float64')
    else:
        pts = None
    # Every rank needs the full point set to verify the global pivot.
    total_pts = comm.bcast(pts, root=0)
    # Presumably scatters rank 0's points across ranks; only the local
    # points are needed here, the returned indices are ignored.
    local_pts, _ = parallel_utils.py_parallel_distribute(pts)
    pivot_dim = ndim - 1

    piv = parallel_utils.py_parallel_pivot_value(local_pts, pivot_dim)

    # Upper bound on how many points may fall strictly on either side.
    nmax = 7 * npts / 10 + 6
    pivot_coords = total_pts[:, pivot_dim]
    assert np.sum(pivot_coords < piv) <= nmax
    assert np.sum(pivot_coords > piv) <= nmax

    # NOTE: the pivot is not compared exactly with the serial
    # utils.py_pivot result. Each process does not hold a multiple of 5
    # points, so the parallel and serial pivots are not equivalent.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号