test_nfft.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:cuvarbase 作者: johnh2o2 项目源码 文件源码
def test_ffts(self):
        t, tsc, y, err = data()

        yhat = np.empty(len(y))

        yg = gpuarray.to_gpu(y.astype(np.complex128))
        yghat = gpuarray.to_gpu(yhat.astype(np.complex128))

        plan = cufft.Plan(len(y), np.complex128, np.complex128)
        cufft.ifft(yg, yghat, plan)

        yhat = fftpack.ifft(y) * len(y)

        tols = dict(rtol=nfft_rtol, atol=nfft_atol)
        assert_allclose(yhat, yghat.get(), **tols)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号