mpi_moments.py source code

python

Project: baselines    Author: openai
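import numpy as np
from mpi4py import MPI

def mpi_moments(x, axis=0):
    """Mean and standard deviation of x along `axis`, pooled over all MPI ranks."""
    x = np.asarray(x, dtype='float64')
    # Shape of the result: x.shape with `axis` removed.
    newshape = list(x.shape)
    newshape.pop(axis)
    n = np.prod(newshape, dtype=int)
    # Pack [sum, sum of squares, count] into one flat buffer so a single
    # Allreduce call adds all three up across every process.
    totalvec = np.zeros(n * 2 + 1, 'float64')
    addvec = np.concatenate([x.sum(axis=axis).ravel(),
                             np.square(x).sum(axis=axis).ravel(),
                             np.array([x.shape[axis]], dtype='float64')])
    MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
    total = totalvec[:n]        # global sum (not named `sum`, to keep the built-in)
    sumsq = totalvec[n:2 * n]   # global sum of squares
    count = totalvec[2 * n]     # global number of samples along `axis`
    if count == 0:
        mean = np.full(newshape, np.nan)
        std = np.full(newshape, np.nan)
    else:
        mean = total / count
        # Var = E[x^2] - E[x]^2; clip small negative values caused by rounding.
        std = np.sqrt(np.maximum(sumsq / count - np.square(mean), 0))
        # The reduced buffers are flat; restore the reduced shape so both
        # branches return arrays of shape `newshape`.
        mean = mean.reshape(newshape)
        std = std.reshape(newshape)
    return mean, std, count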
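To check the reduction logic without launching MPI, the sketch below replaces `MPI.COMM_WORLD.Allreduce(..., op=MPI.SUM)` with a plain element-wise sum over simulated per-rank buffers. The helpers `packed_moments` and `moments_from_packed` are hypothetical names introduced only for this illustration; the sketch assumes NumPy is installed and compares the pooled result with `np.mean`/`np.std` on the full data.

```python
import numpy as np

def packed_moments(x, axis=0):
    """Hypothetical helper: the per-rank buffer [sum, sum of squares, count]."""
    x = np.asarray(x, dtype='float64')
    return np.concatenate([x.sum(axis=axis).ravel(),
                           np.square(x).sum(axis=axis).ravel(),
                           np.array([x.shape[axis]], dtype='float64')])

def moments_from_packed(totalvec, newshape):
    """Hypothetical helper: turn the summed buffer into (mean, std, count)."""
    n = int(np.prod(newshape, dtype=int))
    total, sumsq, count = totalvec[:n], totalvec[n:2 * n], totalvec[2 * n]
    mean = total / count
    std = np.sqrt(np.maximum(sumsq / count - np.square(mean), 0))
    return mean.reshape(newshape), std.reshape(newshape), count

# Simulate 4 "ranks", each holding a different-sized slice along axis 0.
rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=2.0, size=(100, 5))
shards = np.split(data, [10, 35, 70])  # 10, 25, 35 and 30 rows

# Element-wise sum of the per-rank buffers stands in for Allreduce(op=MPI.SUM).
totalvec = sum(packed_moments(s) for s in shards)
mean, std, count = moments_from_packed(totalvec, [5])

print(count)                                 # 100.0
print(np.allclose(mean, data.mean(axis=0)))  # True
print(np.allclose(std, data.std(axis=0)))    # True (population std, ddof=0)
```

Because the sum, the sum of squares and the count are all additive, adding them up across ranks gives the same moments as computing them on the concatenated data, whatever the per-rank sizes. The returned `std` is the population standard deviation (ddof=0). The one-pass formula E[x²] − E[x]² can lose precision when the mean is large relative to the spread, which is why the variance is clamped at 0 before the square root.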