def cwt_freq(data, wavelet, widths, dt, axis):
    """Compute a continuous wavelet transform of ``data`` in frequency space.

    ``data`` is zero-padded (at the end) to the next power of two along
    ``axis`` and Fourier transformed. The result is multiplied by the complex
    conjugate of the normalised wavelet, sampled at every angular frequency
    for every width. It is then inverse transformed and trimmed back to its
    original length.

    Parameters
    ----------
    data : array_like
        Input signal; transformed along ``axis``.
    wavelet : callable
        Called as ``wavelet(w_k, s)``, where ``w_k`` is a 1-D array of ``pN``
        angular frequencies (radians per unit of ``dt``) and ``s`` is the
        widths as a column vector of shape ``(len(widths), 1)``. Presumably it
        returns the wavelet's Fourier-domain values; the result must be
        broadcastable to ``(len(widths), pN)``.
    widths : array_like
        1-D sequence of wavelet widths (scales).
    dt : float
        Sample spacing of ``data``.
    axis : int
        Axis of ``data`` to transform (negative values allowed).

    Returns
    -------
    ndarray (complex)
        Shape ``(len(widths),) + data.shape``. For 1-D ``data`` the result is
        additionally passed through ``squeeze()``, so any length-1 axis is
        dropped.

    Raises
    ------
    ValueError
        If ``data`` has zero length along ``axis``.
    """
    data = np.asarray(data)
    widths = np.asarray(widths)

    N = data.shape[axis]
    if N == 0:
        raise ValueError("cwt_freq: data has zero length along axis")
    # Next power of two >= N for padding, in exact integer arithmetic.
    pN = 1 << (N - 1).bit_length()

    # N.B. padding in fft adds zeros to the *end* of the array,
    # not equally at either end.
    # np.fft is used because scipy.fft is a module (not a callable) in
    # SciPy >= 1.4; np.fft.fft is what the old scipy.fft alias pointed to.
    fft_data = np.fft.fft(data, n=pN, axis=axis)

    # Angular frequencies of the FFT bins.
    w_k = np.fft.fftfreq(pN, d=dt) * 2 * np.pi

    # Sample the wavelet for every width (rows) and normalise.
    norm = (2 * np.pi * widths / dt) ** .5
    wavelet_data = norm[:, None] * wavelet(w_k, widths[:, None])

    # Convert a negative axis, then add one to account for the
    # leading widths axis of the output.
    axis = (axis % data.ndim) + 1

    # Perform the convolution in frequency space. Broadcast the
    # (widths, freq) wavelet against the data: keep axis 0 and the
    # transform axis, insert length-1 axes everywhere else. Index with a
    # tuple -- list-of-slices indexing is an error in modern NumPy.
    bcast = [None] * (data.ndim + 1)
    bcast[0] = slice(None)
    bcast[axis] = slice(None)
    out = np.fft.ifft(fft_data[None] * wavelet_data.conj()[tuple(bcast)],
                      n=pN, axis=axis)

    # Remove zero padding.
    trim = [slice(None)] * out.ndim
    trim[axis] = slice(None, N)
    out = out[tuple(trim)]

    if data.ndim == 1:
        return out.squeeze()
    return out
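# Web-page residue from where this snippet was copied:
# "评论列表" (comment list), "文章目录" (article table of contents).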