ica.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:decoding_challenge_cortana_2016_3rd 作者: kingjr 项目源码 文件源码
def _pick_sources(self, data, include, exclude):
        """Project data through the retained ICA components and back.

        The data are centered with the PCA mean (if any), multiplied by a
        projection matrix built from the unmixing and mixing operators
        restricted to the retained components, un-centered again and finally
        rescaled with the (pseudo-)inverse of the pre-whitener.

        Parameters
        ----------
        data : ndarray
            Pre-whitened data to clean. Indexed as a 2D array whose first
            axis matches ``self.pca_mean_`` (presumably
            ``(n_channels, n_times)`` -- confirm against callers).
            NOTE: when ``self.pca_mean_`` is not None the array is centered
            in place, so the caller's array is modified.
        include : array-like of int | int | None
            Component indices to keep. If given and non-empty it takes
            precedence and ``exclude`` is ignored.
        exclude : array-like of int | None
            Component indices to zero out, merged with ``self.exclude``.
            If None, ``self.exclude`` alone is used.

        Returns
        -------
        data : ndarray
            The data after removing the unselected components.

        Raises
        ------
        ValueError
            If the resolved number of PCA components is not between
            ``self.n_components_`` and ``self.max_pca_components``.
        """
        fast_dot = _get_fast_dot()
        if exclude is None:
            exclude = self.exclude
        else:
            # list(...) on both sides so that tuples / arrays can be merged
            # as well (``ndarray + list`` would add element-wise instead).
            exclude = list(set(list(self.exclude) + list(exclude)))

        _n_pca_comp = self._check_n_pca_components(self.n_pca_components)

        if not(self.n_components_ <= _n_pca_comp <= self.max_pca_components):
            raise ValueError('n_pca_components must be >= '
                             'n_components and <= max_pca_components.')

        n_components = self.n_components_
        logger.info('Transforming to ICA space (%i components)' % n_components)

        # Apply first PCA
        if self.pca_mean_ is not None:
            data -= self.pca_mean_[:, None]

        # Select the ICA components to keep. ``x not in (None, [])`` is
        # avoided on purpose: with a NumPy array it triggers an element-wise
        # ``==`` whose truth value is ambiguous and raises ValueError.
        sel_keep = np.arange(n_components)
        if include is not None and np.size(include) > 0:
            sel_keep = np.unique(include)
        elif exclude is not None and np.size(exclude) > 0:
            sel_keep = np.setdiff1d(np.arange(n_components), exclude)

        logger.info('Zeroing out %i ICA components'
                    % (n_components - len(sel_keep)))

        # Identity-padded ICA unmixing applied after the first _n_pca_comp
        # PCA components.
        unmixing = np.eye(_n_pca_comp)
        unmixing[:n_components, :n_components] = self.unmixing_matrix_
        unmixing = np.dot(unmixing, self.pca_components_[:_n_pca_comp])

        # Identity-padded ICA mixing followed by the transposed PCA
        # components.
        mixing = np.eye(_n_pca_comp)
        mixing[:n_components, :n_components] = self.mixing_matrix_
        mixing = np.dot(self.pca_components_[:_n_pca_comp].T, mixing)

        # PCA components beyond the ICA ones are always retained.
        if _n_pca_comp > n_components:
            sel_keep = np.concatenate(
                (sel_keep, np.arange(n_components, _n_pca_comp)))

        proj_mat = np.dot(mixing[:, sel_keep], unmixing[sel_keep, :])

        data = fast_dot(proj_mat, data)

        if self.pca_mean_ is not None:
            data += self.pca_mean_[:, None]

        # restore scaling
        if self.noise_cov is None:  # revert standardization
            data *= self._pre_whitener
        else:
            data = fast_dot(linalg.pinv(self._pre_whitener), data)

        return data
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号