def _get_epochs_interpolation(self, epochs, drop_log,
                              ch_type, verbose='progressbar'):
    """Choose the channels to interpolate in every epoch.

    For each epoch, the channels in ``self.picks`` that are marked ``1``
    (bad) in ``drop_log`` are the candidates. If there are at most
    ``self.n_interpolate[ch_type]`` of them, all are selected. Otherwise,
    only that many with the largest peak-to-peak amplitude are kept.

    Parameters
    ----------
    epochs : Epochs-like
        Presumably an ``mne.Epochs`` instance (TODO confirm against callers).
        It must support ``len(epochs)`` and expose ``ch_names``.
        ``epochs[i].get_data()[0]`` must be an array with one row per
        channel.
    drop_log : ndarray, shape (n_epochs, n_channels)
        Per-epoch, per-channel labels; ``1`` marks a bad segment.
    ch_type : str
        Key into ``self.n_interpolate`` giving the per-epoch budget of
        channels that may be interpolated.
    verbose : str
        Unused; kept for API compatibility.

    Returns
    -------
    interp_channels : list of list of str
        One entry per epoch, so ``interp_channels[i]`` refers to
        ``epochs[i]``. The entry is empty when nothing needs interpolation.
        Names appear in channel-index order.
    fix_log : ndarray, shape (n_epochs, n_channels)
        Copy of ``drop_log`` with the selected channels set to ``2``.
        ``drop_log`` itself is not modified.
    """
    # Label convention -- 1: bad segment, 2: interpolated.
    fix_log = drop_log.copy()
    ch_names = epochs.ch_names
    n_interpolate = self.n_interpolate[ch_type]

    # Only channels in ``self.picks`` may be interpolated. Build the mask
    # once so both the "few bads" and "many bads" paths respect it.
    pick_mask = np.zeros(fix_log.shape[1], dtype=bool)
    pick_mask[self.picks] = True

    interp_channels = list()
    for epoch_idx in range(len(epochs)):
        # Fresh array (not a view of drop_log), so it is safe to edit below.
        interp_chs_mask = (drop_log[epoch_idx] == 1) & pick_mask
        n_bads = np.count_nonzero(interp_chs_mask)
        if n_bads > n_interpolate:
            # Too many bad channels: keep only the worst ``n_interpolate``,
            # ranked by peak-to-peak amplitude within this epoch.
            data = epochs[epoch_idx].get_data()[0]
            # float cast: integer data could not hold the -inf sentinel.
            peaks = np.ptp(data, axis=-1).astype(float)
            # Non-candidates sort last so they are never selected.
            peaks[~interp_chs_mask] = -np.inf
            worst_first = np.argsort(peaks)[::-1]
            interp_chs_mask[worst_first[n_interpolate:]] = False
        fix_log[epoch_idx, interp_chs_mask] = 2
        # Always append (even an empty list) so list positions stay aligned
        # with epoch indices and with the rows of ``fix_log``.
        interp_channels.append(
            [ch_names[idx] for idx in np.flatnonzero(interp_chs_mask)])
    return interp_channels, fix_log
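# NOTE: stray web-page residue was found here ("评论列表" / "文章目录",
# i.e. "comment list" / "table of contents"); it is not part of the code.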