def test_FFT(FFT):
    """Check ``FFT.fftn`` / ``FFT.ifftn`` against reference data from rank 0.

    Rank 0 builds a random real array ``A`` and its reference spectrum ``B2``
    with ``rfftn``; both are broadcast to every rank through ``FFT.comm``.
    Each rank then transforms its own local slice of ``A`` with ``FFT.fftn``
    and back with ``FFT.ifftn`` and compares the results with the matching
    local slices of the broadcast reference arrays.

    Parameters
    ----------
    FFT : object
        Transform object under test.  It must expose ``N``, ``rank``,
        ``comm``, ``communication``, ``float``, ``complex``, the shape
        helpers ``global_complex_shape()``, ``real_shape()`` and
        ``complex_shape()``, the slice helpers ``real_local_slice()`` and
        ``complex_local_slice()``, and the transforms ``fftn(a, c)`` and
        ``ifftn(c, a)``.

    Raises
    ------
    AssertionError
        If any normalised forward or backward error is not below the
        tolerance (1e-8 when ``FFT.float`` is ``float64``, otherwise 1e-4).
    """
    N = FFT.N
    if FFT.rank == 0:
        A = random(N).astype(FFT.float)
        if FFT.communication == 'AlltoallN':
            # Remove the Nyquist frequency from the input by zeroing the last
            # index of the third spectral axis and transforming back.
            # NOTE(review): presumably the 'AlltoallN' variant does not carry
            # that mode -- confirm against the FFT implementation.
            C = empty(FFT.global_complex_shape(), dtype=FFT.complex)
            C = rfftn(A, C, axes=(0, 1, 2))
            C[:, :, -1] = 0
            A = irfftn(C, A, axes=(0, 1, 2))
        # Reference spectrum of the (possibly filtered) input.
        B2 = zeros(FFT.global_complex_shape(), dtype=FFT.complex)
        B2 = rfftn(A, B2, axes=(0, 1, 2))
    else:
        # Receive buffers for the broadcast below.
        A = zeros(N, dtype=FFT.float)
        B2 = zeros(FFT.global_complex_shape(), dtype=FFT.complex)

    rtol = 1e-8 if FFT.float is float64 else 1e-4
    FFT.comm.Bcast(A, root=0)
    FFT.comm.Bcast(B2, root=0)

    a = zeros(FFT.real_shape(), dtype=FFT.float)
    c = zeros(FFT.complex_shape(), dtype=FFT.complex)
    a[:] = A[FFT.real_local_slice()]

    # Forward transform: compare with this rank's slice of the reference.
    # NOTE(review): c is complex and NumPy orders complex values by real part
    # first, so c.max() need not be the entry of largest magnitude -- confirm
    # this is the intended scale.
    c = FFT.fftn(a, c)
    forward_err = abs((c - B2[FFT.complex_local_slice()]) / c.max())
    assert (forward_err < rtol).all(), (
        "forward transform: max normalised error %g is not below %g"
        % (forward_err.max(), rtol))

    # Backward transform: must reproduce this rank's slice of the input.
    a = FFT.ifftn(c, a)
    backward_err = abs((a - A[FFT.real_local_slice()]) / a.max())
    assert (backward_err < rtol).all(), (
        "backward transform: max normalised error %g is not below %g"
        % (backward_err.max(), rtol))
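# Comment list (web page navigation text, not part of the code)
# Table of contents (web page navigation text, not part of the code)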