def _fftconvolve_next_pow2(n):
    """Return the smallest power of two that is >= ``n`` (for ``n >= 1``)."""
    # Pure integer arithmetic: avoids float rounding from log2/ceil.
    return 1 << (int(n) - 1).bit_length()


def _fftconvolve_centered(arr, newsize):
    """Return the centered sub-array of ``arr`` with shape ``newsize``."""
    newsize = np.asarray(newsize)
    start = (np.array(arr.shape) - newsize) // 2
    stop = start + newsize
    return arr[tuple(slice(int(a), int(b)) for a, b in zip(start, stop))]


def fftconvolve(in1, in2, mode="full", axis=None):
    """Convolve two N-dimensional arrays using FFT.

    Variant of ``scipy.signal.fftconvolve`` with an additional ``axis``
    argument that restricts the convolution to a single axis.

    Parameters
    ----------
    in1, in2 : array_like
        Input arrays; they must have the same number of dimensions.
        When ``axis`` is given, their shapes must also match on every
        axis other than ``axis``.
    mode : {"full", "same", "valid"}, optional
        "full" returns the complete discrete linear convolution.
        "same" returns the central part of the full result, with the shape
        of whichever input has more elements (``in2`` on a tie).
        "valid" returns the central part of the full result whose length on
        each convolved axis is ``abs(len1 - len2) + 1``.
    axis : int or None, optional
        Axis along which to convolve. ``None`` (default) convolves over
        all axes.

    Returns
    -------
    numpy.ndarray
        The convolution result; complex if either input is complex,
        real otherwise. An empty array if either input is empty.

    Raises
    ------
    ValueError
        If ``mode`` is not recognised, the inputs differ in dimensionality,
        or (with ``axis`` given) the shapes differ on a non-convolving axis.
    """
    in1 = np.asarray(in1)
    in2 = np.asarray(in2)

    if mode not in ("full", "same", "valid"):
        raise ValueError(
            "Acceptable mode flags are 'valid', 'same', or 'full'.")
    if in1.ndim != in2.ndim:
        raise ValueError("in1 and in2 should have the same dimensionality")
    if in1.size == 0 or in2.size == 0:
        return np.array([])

    s1 = np.array(in1.shape)
    s2 = np.array(in2.shape)
    # np.complex was removed from NumPy; complexfloating is the abstract
    # parent of every complex dtype.
    complex_result = (np.issubdtype(in1.dtype, np.complexfloating) or
                      np.issubdtype(in2.dtype, np.complexfloating))

    if axis is None:
        full_shape = s1 + s2 - 1
        fslice = tuple(slice(0, int(sz)) for sz in full_shape)
        # Always use 2**n-sized FFTs, computed independently per axis.
        fshape = tuple(_fftconvolve_next_pow2(sz) for sz in full_shape)
        spectrum = fftpack.fftn(in1, fshape)
        spectrum *= fftpack.fftn(in2, fshape)
        ret = fftpack.ifftn(spectrum)[fslice].copy()
        valid_shape = np.abs(s2 - s1) + 1
    else:
        equal_shapes = s1 == s2
        # Only the convolving axis is allowed to differ in length.
        equal_shapes[axis] = True
        if not equal_shapes.all():
            raise ValueError('Shape mismatch on non-convolving axes')
        full_len = int(s1[axis] + s2[axis] - 1)
        fslice = [slice(None)] * in1.ndim
        fslice[axis] = slice(0, full_len)
        fslice = tuple(fslice)
        # Always use a 2**n-sized FFT.
        nfft = _fftconvolve_next_pow2(full_len)
        spectrum = fftpack.fft(in1, nfft, axis=axis)
        spectrum *= fftpack.fft(in2, nfft, axis=axis)
        ret = fftpack.ifft(spectrum, axis=axis)[fslice].copy()
        # Non-convolving axes keep their full length in "valid" mode;
        # only the convolved axis shrinks.
        valid_shape = s1.copy()
        valid_shape[axis] = abs(int(s1[axis]) - int(s2[axis])) + 1
    # Release the (potentially large) padded spectrum before cropping.
    del spectrum

    if not complex_result:
        ret = ret.real

    if mode == "full":
        return ret
    if mode == "same":
        osize = s1 if np.prod(s1) > np.prod(s2) else s2
        return _fftconvolve_centered(ret, osize)
    # mode == "valid"
    return _fftconvolve_centered(ret, valid_shape)
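# Comment list (apparently navigation text left over from the source web page)
# Article table of contents (apparently navigation text left over from the source web page)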