p_index.py source code

python

Project: LSDMap  Author: ClementiGroup
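import numpy as np
from mpi4py import MPI


def get_idxs_thread(comm, npoints):
    """Get the point indices assigned to this rank using Scatterv.

    Note
    ----
    Uppercase mpi4py functions operate on raw buffers, so every array
    must use a C-compatible dtype (here np.intc, matching MPI.INT) or
    the results will be garbage.
    """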

    size = comm.Get_size()
    rank = comm.Get_rank()

    npoints_thread = np.zeros(size,dtype=np.intc)
    offsets_thread = np.zeros(size,dtype=np.intc)

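    # Give every rank the same base share; floor division keeps the count an integer.
    for idx in range(size):
        npoints_thread[idx] = npoints // size
        offsets_thread[idx] = sum(npoints_thread[:idx])

    # Hand the leftover points, one each, to the lowest ranks and shift
    # the offsets of all later ranks accordingly.
    for idx in range(npoints % size):
        npoints_thread[idx] += 1
        offsets_thread[idx + 1:] += 1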

    npoints_thread = tuple(npoints_thread)
    offsets_thread = tuple(offsets_thread)

    idxs_thread = np.zeros(npoints_thread[rank],dtype=np.intc)
    idxs = np.arange(npoints,dtype=np.intc)

    comm.Scatterv((idxs, npoints_thread, offsets_thread, MPI.INT), idxs_thread, root=0)
    return idxs_thread, npoints_thread, offsets_thread
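
The counts and offsets passed to Scatterv above can be checked without launching MPI. The following is a minimal sketch in plain Python, assuming the same split as the function above (an equal base share per rank, with the remainder handed one point each to the lowest ranks). The helper name `block_partition` is hypothetical and is not part of LSDMap or mpi4py.

```python
from itertools import accumulate


def block_partition(npoints, size):
    """Return (counts, offsets) for splitting npoints items over size ranks."""
    base, extra = divmod(npoints, size)
    # The first `extra` ranks each take one additional item.
    counts = [base + 1 if rank < extra else base for rank in range(size)]
    # Each offset is the running total of the counts that precede it.
    offsets = [0] + list(accumulate(counts))[:-1]
    return counts, offsets


counts, offsets = block_partition(10, 4)
print(counts)   # [3, 3, 2, 2]
print(offsets)  # [0, 3, 6, 8]
```

With 10 points on 4 ranks, rank 2 would receive the 2 indices starting at position 6, which is the slice Scatterv delivers into that rank's `idxs_thread` buffer.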