def fft_convolve_1d(x: np.ndarray, h: np.ndarray) -> np.ndarray:
    """Convolve two 1-D sequences via FFT and return the centred ``len(x)`` part.

    The full linear convolution (length ``len(x) + len(h) - 1``) is computed
    with zero-padded FFTs. The slice of length ``len(x)`` centred on it is
    then returned. This is the same result as
    ``scipy.signal.fftconvolve(x, h, mode="same")``.

    Parameters
    ----------
    x : array_like
        Input signal (1-D). Lists and other array-likes are accepted.
    h : array_like
        Kernel (1-D).

    Returns
    -------
    np.ndarray
        Array of length ``len(x)``. It is real-valued when both inputs are
        real and complex when either input is complex. It is empty if either
        input is empty.
    """
    x = np.asarray(x)
    h = np.asarray(h)
    if x.size == 0 or h.size == 0:
        # Convolution with an empty sequence is empty (as scipy does).
        return np.array([])

    n = len(x) + len(h) - 1  # length of the full linear convolution
    # Pad to the next power of two >= n. Any length >= n makes the circular
    # FFT convolution equal the linear one; a power of two is FFT-friendly.
    n_fft = 1 << (n - 1).bit_length()

    if np.issubdtype(x.dtype, np.complexfloating) or np.issubdtype(h.dtype, np.complexfloating):
        fft, ifft = np.fft.fft, np.fft.ifft
    else:
        # Real inputs: rfft/irfft do about half the work and yield a real result.
        fft, ifft = np.fft.rfft, np.fft.irfft

    full = ifft(fft(x, n_fft) * fft(h, n_fft), n_fft)[:n]

    # Keep the len(x) samples centred in the full output. An explicit end
    # index is used because a symmetric ``full[k:-k]`` slice is empty when
    # k == 0 (len(h) <= 2). It also keeps one sample too many when the
    # kernel length is even.
    start = (n - len(x)) // 2
    return full[start:start + len(x)]
评论列表
文章目录