test_fft.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:bifrost 作者: ledatelescope 项目源码 文件源码
def run_test_c2r_impl(self, shape, axes, fftshift=False):
        """Run a complex-to-real FFT in the 'cuda' space and check it against a reference.

        Args:
            shape: Nominal real-domain shape of the data (sequence of ints).
            axes: Axes to transform; the last entry is the axis that is
                halved on the complex (input) side.
            fftshift: Forwarded to ``Fft.init`` as ``apply_fftshift``; when
                true, the reference input is ``ifftshift``-ed over ``axes``
                before the reference transform.

        The comparison itself is delegated to ``compare``; this method
        returns nothing.
        """
        last_axis = axes[-1]
        ishape = list(shape)
        oshape = list(shape)
        # The complex input holds N//2 + 1 bins along the last transformed
        # axis; the real output along that axis then has 2 * (bins - 1)
        # samples, which is shape - 1 when the requested length is odd.
        ishape[last_axis] = shape[last_axis] // 2 + 1
        oshape[last_axis] = (ishape[last_axis] - 1) * 2
        # For complex: generate twice as many float32 values along the last
        # dimension, then reinterpret adjacent pairs as complex64.
        ishape[-1] *= 2
        known_data = (np.random.normal(size=ishape)
                      .astype(np.float32)
                      .view(np.complex64))
        idata = bf.ndarray(known_data, space='cuda')
        odata = bf.ndarray(shape=oshape, dtype='f32', space='cuda')
        fft = Fft()
        fft.init(idata, odata, axes=axes, apply_fftshift=fftshift)
        fft.execute(idata, odata)
        # Note: Numpy applies normalization while CUFFT does not.
        # The factor is taken from the actual output shape (not the
        # requested one) so that it stays correct when the last transformed
        # axis has an odd requested length.
        # NOTE(review): assumes gold_irfftn normalizes by the output size,
        # as np.fft.irfftn does by default -- confirm against its definition.
        # int() keeps a plain Python int so the float32 reference is not
        # promoted by a NumPy integer scalar.
        norm = int(np.prod([oshape[d] for d in axes]))
        if fftshift:
            known_data = np.fft.ifftshift(known_data, axes=axes)
        known_result = gold_irfftn(known_data, axes=axes) * norm
        compare(odata.copy('system'), known_result)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号