def update(self, x):
x = x.astype('float64')
n = int(np.prod(self.shape))
totalvec = np.zeros(n*2+1, 'float64')
addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')])
MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n])
评论列表
文章目录