mpi_running_mean_std.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:rl-teacher 作者: nottombrown 项目源码 文件源码
def update(self, x):
        x = x.astype('float64')
        n = int(np.prod(self.shape))
        totalvec = np.zeros(n*2+1, 'float64')
        addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')])
        MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
        self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号