test_FFT.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:mpiFFT4py 作者: spectralDNS 项目源码 文件源码
def test_FFT(FFT):
    """Check ``FFT.fftn`` / ``FFT.ifftn`` against a reference built on rank 0.

    Rank 0 draws a random real array ``A`` of shape ``FFT.N`` and computes
    its reference transform ``B2`` with ``rfftn`` over axes ``(0, 1, 2)``.
    Both arrays are broadcast to every rank.  Each rank then transforms its
    local slice of ``A`` with ``FFT.fftn`` and compares the result with its
    local slice of ``B2``, and finally transforms back with ``FFT.ifftn``
    and compares with the local slice of ``A``.

    Parameters
    ----------
    FFT : object
        Transform object under test.  This function uses its ``N``,
        ``rank``, ``comm``, ``float``, ``complex`` and ``communication``
        attributes and its ``global_complex_shape``, ``real_shape``,
        ``complex_shape``, ``real_local_slice``, ``complex_local_slice``,
        ``fftn`` and ``ifftn`` methods.

    Raises
    ------
    AssertionError
        If any element of the forward or backward result deviates from the
        reference by ``rtol`` or more, relative to the ``max()`` of the
        computed array.
    """
    N = FFT.N
    if FFT.rank == 0:
        A = random(N).astype(FFT.float)
        if FFT.communication == 'AlltoallN':
            # Round-trip the input once with the last index of the third
            # axis zeroed, so the data fed to both the reference and the
            # tested transform carries no Nyquist component.
            # NOTE(review): presumably the 'AlltoallN' mode does not retain
            # that mode -- confirm against the FFT implementation.
            C = empty(FFT.global_complex_shape(), dtype=FFT.complex)
            C = rfftn(A, C, axes=(0, 1, 2))
            C[:, :, -1] = 0  # Remove Nyquist frequency
            A = irfftn(C, A, axes=(0, 1, 2))
        B2 = zeros(FFT.global_complex_shape(), dtype=FFT.complex)
        B2 = rfftn(A, B2, axes=(0, 1, 2))
    else:
        # Receive buffers only; they are filled by the broadcasts below.
        A = zeros(N, dtype=FFT.float)
        B2 = zeros(FFT.global_complex_shape(), dtype=FFT.complex)

    # Looser tolerance when FFT.float is anything other than float64.
    rtol = 1e-8 if FFT.float is float64 else 1e-4

    FFT.comm.Bcast(A, root=0)
    FFT.comm.Bcast(B2, root=0)

    a = zeros(FFT.real_shape(), dtype=FFT.float)
    c = zeros(FFT.complex_shape(), dtype=FFT.complex)
    a[:] = A[FFT.real_local_slice()]

    # Forward transform versus the local part of the reference.
    # NOTE(review): c is complex, so c.max() follows NumPy's lexicographic
    # ordering (largest real part), not the largest magnitude -- kept as in
    # the original check; confirm this normalisation is intended.
    c = FFT.fftn(a, c)
    forward_err = abs((c - B2[FFT.complex_local_slice()]) / c.max())
    # Use the array's own .all() so the check does not depend on whether
    # NumPy's all() has shadowed the builtin in this module.
    assert (forward_err < rtol).all(), (
        "forward transform: max relative error %g is not below rtol=%g"
        % (forward_err.max(), rtol))

    # Backward transform versus the local part of the original input.
    a = FFT.ifftn(c, a)
    backward_err = abs((a - A[FFT.real_local_slice()]) / a.max())
    assert (backward_err < rtol).all(), (
        "backward transform: max relative error %g is not below rtol=%g"
        % (backward_err.max(), rtol))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号