def spike_filtfilt(data, lowcut=None, highcut=None, *, fs=None, verbose=False):
    """Filter data to the spike band (default 600--6000 Hz).

    The band-pass is built from two Butterworth filters: a first-order
    low-pass at ``highcut`` is applied first, then a second-order high-pass
    at ``lowcut``. Each is run through ``filtfilt`` (forward and backward).

    Parameters
    ----------
    data : AnalogSignalArray, ndarray, or list
        Signal to filter. For an AnalogSignalArray the samples are read
        from ``data.ydata``.
    lowcut : float, optional (default 600 Hz)
        Lower cut-off frequency
    highcut : float, optional (default 6000 Hz)
        Upper cut-off frequency
    fs : float, optional if AnalogSignalArray is passed
        Sampling frequency (Hz). Required for ndarray/list input; for an
        AnalogSignalArray it defaults to ``data.fs``.
    verbose : bool, optional
        Accepted for interface compatibility; currently unused.

    Returns
    -------
    filtered : ndarray or AnalogSignalArray
        An ndarray when ``data`` is an ndarray or a list; otherwise a
        shallow copy of the AnalogSignalArray whose ``_ydata`` holds the
        filtered samples (the input object is not modified).

    Raises
    ------
    TypeError
        If ``data`` is not an AnalogSignalArray, ndarray, or list.
    ValueError
        If ``data`` is an ndarray or list and ``fs`` is not given.
    """
    is_raw = isinstance(data, (np.ndarray, list))

    # Reject unsupported inputs up front: previously such inputs either
    # failed later with an unrelated error or silently returned None.
    if not is_raw and not isinstance(data, AnalogSignalArray):
        raise TypeError(
            "data must be an AnalogSignalArray, ndarray, or list, not "
            "{}".format(type(data).__name__)
        )

    if fs is None:
        if is_raw:
            raise ValueError("sampling frequency must be specified!")
        fs = data.fs

    if lowcut is None:
        lowcut = 600
    if highcut is None:
        highcut = 6000

    # butter() expects cut-offs normalized by the Nyquist frequency.
    nyquist = fs / 2
    b_high, a_high = butter(2, lowcut / nyquist, btype='highpass')
    b_low, a_low = butter(1, highcut / nyquist)

    if is_raw:
        # Filter raw data: low-pass first, then high-pass.
        return filtfilt(b_high, a_high, filtfilt(b_low, a_low, data))

    spikedata = filtfilt(b_high, a_high, filtfilt(b_low, a_low, data.ydata))
    # Return a copy of the AnalogSignalArray with the filtered data
    out = copy.copy(data)
    out._ydata = spikedata
    return out
# (Web-page navigation residue, not part of the module: "Comment list" / "Table of contents".)