srm.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目: brainiak 作者: brainiak 项目源码 文件源码
def fit(self, X, y=None):
        """Compute the probabilistic Shared Response Model.

        The subject list may be distributed over the MPI ranks of
        ``self.comm``: every rank must pass a list of the same length, and
        an entry is ``None`` on a rank that does not hold that subject's
        data. Per-subject sample counts are summed across ranks, so
        presumably each subject's array is present on exactly one rank
        (TODO confirm against callers).

        Parameters
        ----------
        X :  list of 2D arrays, element i has shape=[voxels_i, samples]
            Each element in the list contains the fMRI data of one subject.

        y : not used

        Returns
        -------
        self : object
            The fitted model. ``sigma_s_``, ``w_``, ``mu_``, ``rho2_`` and
            ``s_`` are set from the result of ``self._srm(X)``.

        Raises
        ------
        ValueError
            If the ranks disagree on the number of subjects, there are fewer
            than two subjects, a subject has fewer samples than
            ``self.features``, or subjects have different numbers of samples.
        """
        logger.info('Starting Probabilistic SRM')

        number_subjects = len(X)

        # Exchange subject counts before raising anything rank-specific. If
        # one rank raised here while the others entered the collective, the
        # others would block in allgather forever.
        number_subjects_vec = self.comm.allgather(number_subjects)
        if any(count != number_subjects for count in number_subjects_vec):
            raise ValueError(
                "Not all ranks have same number of subjects")

        # Check the number of subjects. The count is identical on every rank
        # at this point, so all ranks raise together.
        if number_subjects <= 1:
            raise ValueError("There are not enough subjects "
                             "({0:d}) to train the model.".format(
                                 number_subjects))

        # Collect per-subject sample counts. Ranks that do not hold a subject
        # contribute 0 and the allreduce sums the contributions.
        # Builtin ``int`` replaces the ``np.int`` alias (removed in NumPy 1.24).
        local_samples = np.zeros((number_subjects,), dtype=int)
        for subject, data in enumerate(X):
            if data is not None:
                assert_all_finite(data)
                local_samples[subject] = data.shape[1]
        samples_per_subject = self.comm.allreduce(local_samples, op=MPI.SUM)

        # Check that every subject has enough TRs and the same number of TRs.
        # The per-subject order of checks is kept so the same error fires as
        # before for any given input.
        number_trs = np.min(samples_per_subject)
        for subject_samples in samples_per_subject:
            if subject_samples < self.features:
                raise ValueError(
                    "There are not enough samples to train the model with "
                    "{0:d} features.".format(self.features))
            if subject_samples != number_trs:
                raise ValueError("Different number of samples between subjects"
                                 ".")

        # Run SRM
        self.sigma_s_, self.w_, self.mu_, self.rho2_, self.s_ = self._srm(X)

        return self
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号