def test_runningmeanstd():
    """Check ``mpi_moments`` against moments of the concatenated data.

    Must be launched under exactly three MPI processes (e.g.
    ``mpirun -np 3 ...``).  For each case, every rank builds the same three
    chunks, then passes only its own chunk (indexed by rank) to
    ``mpi_moments``.  The result must match the mean, standard deviation and
    element count of the full concatenation along ``axis``.

    Raises:
        RuntimeError: if the communicator does not have exactly 3 ranks.
        AssertionError: if a distributed moment disagrees with the value
            computed locally on the concatenated array.
    """
    comm = MPI.COMM_WORLD
    nranks = comm.Get_size()
    # Every case below splits its data into exactly three chunks, one per
    # rank.  With fewer ranks the comparison would use only part of the data
    # and fail misleadingly; with more, indexing by rank would overflow.
    if nranks != 3:
        raise RuntimeError(
            "test_runningmeanstd must run under exactly 3 MPI processes, "
            "got {}".format(nranks)
        )
    # The same seed on every rank makes every rank draw identical chunks,
    # so each rank can compute the reference moments on its own.
    np.random.seed(0)
    cases = [
        ((np.random.randn(3), np.random.randn(4), np.random.randn(5)), 0),
        (
            (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
            0,
        ),
        (
            (np.random.randn(2, 3), np.random.randn(2, 4), np.random.randn(2, 4)),
            1,
        ),
    ]
    for triple, axis in cases:
        x = np.concatenate(triple, axis=axis)
        # Reference values: mean, std, and count along the split axis.
        ms1 = [x.mean(axis=axis), x.std(axis=axis), x.shape[axis]]
        ms2 = mpi_moments(triple[comm.Get_rank()], axis=axis)
        for a1, a2 in zipsame(ms1, ms2):
            print(a1, a2)
            assert np.allclose(a1, a2)
        print("ok!")
# (Web-page residue, not code: "Comment list")
# (Web-page residue, not code: "Table of contents")