cuda.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:decoding_challenge_cortana_2016_3rd 作者: kingjr 项目源码 文件源码
def fft_multiply_repeated(h_fft, x, cuda_dict=None):
    """Do FFT multiplication by a filter function (possibly using CUDA).

    Parameters
    ----------
    h_fft : 1-d array or gpuarray
        The filtering array to apply, given in the Fourier domain (it is
        multiplied element-wise with the FFT of ``x``).
    x : 1-d array
        The array to filter.
    cuda_dict : dict | None
        Dictionary constructed using setup_cuda_multiply_repeated(). It must
        contain the key ``'use_cuda'``; when that is true it must also hold
        the device buffers ``'x'`` and ``'x_fft'`` and the plans
        ``'fft_plan'`` and ``'ifft_plan'``. If None (default), the CPU path
        is used, exactly as if ``dict(use_cuda=False)`` had been passed.

    Returns
    -------
    x : 1-d array
        Filtered version of x. On the CPU path this is the (flattened) real
        part of the inverse FFT; on the CUDA path it is cast back to the
        dtype of the input ``x``.
    """
    if cuda_dict is None:
        # Build a fresh dict per call rather than sharing a mutable default.
        cuda_dict = dict(use_cuda=False)
    if not cuda_dict['use_cuda']:
        # Multiply in the Fourier domain on the CPU. The inverse transform is
        # complex-valued, so keep only the real part and flatten the result.
        x = np.real(ifft(h_fft * fft(x), overwrite_x=True)).ravel()
    else:
        cudafft = _get_cudafft()
        # Upload x (as float64) into the preallocated device buffer.
        cuda_dict['x'].set(x.astype(np.float64))
        # Forward FFT; the result is written into the second argument.
        cudafft.fft(cuda_dict['x'], cuda_dict['x_fft'], cuda_dict['fft_plan'])
        # Presumably multiplies x_fft by h_fft in place with a custom
        # complex128 kernel -- verify against _multiply_inplace_c128.
        _multiply_inplace_c128(h_fft, cuda_dict['x_fft'])
        # If we wanted to do it locally instead of using our own kernel:
        # cuda_seg_fft.set(cuda_seg_fft.get() * h_fft)
        cudafft.ifft(cuda_dict['x_fft'], cuda_dict['x'],
                     cuda_dict['ifft_plan'], False)
        # asanyarray copies only when the dtype differs and preserves
        # ndarray subclasses. The former np.array(..., copy=False) raises
        # ValueError on NumPy >= 2 whenever a dtype conversion needs a copy.
        x = np.asanyarray(cuda_dict['x'].get(), dtype=x.dtype)
    return x


###############################################################################
# FFT Resampling
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号