sssrm.py source code

python

Project: brainiak  Author: brainiak
# Module-level imports this method relies on
import numpy as np
from sklearn.exceptions import NotFittedError

from brainiak.utils import utils


def predict(self, X):
    """Predict class labels for the given data

    Parameters
    ----------
    X : list of 2D arrays, element i has shape=[voxels_i, samples_i]
        Each element in the list contains the fMRI data of one subject.
        The number of voxels for each subject must match the number used
        for that subject when the model was trained.

    Returns
    -------
    p : list of arrays, element i has shape=[samples_i]
        Predicted class label for each data sample of subject i.
    """
    # Check that the model has been fitted
    if not hasattr(self, 'w_'):
        raise NotFittedError("The model fit has not been run yet.")

    # Check that the number of subjects matches the fitted model
    if len(X) != len(self.w_):
        raise ValueError("The number of subjects does not match the one"
                         " in the model.")

    # Project each subject's data into the shared response space
    X_shared = self.transform(X)
    p = [None] * len(X_shared)
    for subject in range(len(X_shared)):
        # Class scores theta^T * s + bias, exponentiated via
        # utils.sumexp_stable to avoid overflow
        sumexp, _, exponents = utils.sumexp_stable(
            self.theta_.T.dot(X_shared[subject]) + self.bias_)
        # Softmax per sample (column), then pick the most probable class
        p[subject] = self.classes_[
            (exponents / sumexp[np.newaxis, :]).argmax(axis=0)]

    return p
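To make the classification step concrete, here is a minimal stand-alone sketch of the per-subject loop, under stated assumptions. `sumexp_stable` below is a hypothetical re-implementation that mirrors how the method unpacks `utils.sumexp_stable` (sum of exponentials, column maximum, shifted exponentials); it is not copied from brainiak. `predict_labels` is likewise a hypothetical helper, and the layouts of `theta` ([features, n_classes]) and `bias` ([n_classes, 1]) are assumptions chosen so that `theta.T.dot(x) + bias` broadcasts the same way as in `predict`.

```python
import numpy as np


def sumexp_stable(data):
    """Column-wise sum of exponentials, computed without overflow.

    Hypothetical stand-in for utils.sumexp_stable: subtracting each column's
    maximum before np.exp keeps the exponents <= 0, and the shift cancels in
    the softmax ratio exponents / sumexp.
    Returns (sum_of_exponents, column_max, shifted_exponents).
    """
    max_value = data.max(axis=0)              # shape [samples]
    exponents = np.exp(data - max_value)      # shape [n_classes, samples]
    return exponents.sum(axis=0), max_value, exponents


def predict_labels(theta, bias, classes, shared_data):
    """Hypothetical stand-alone version of the loop in predict().

    theta:       [features, n_classes] weights (assumed layout)
    bias:        [n_classes, 1] column of per-class biases (assumed layout)
    classes:     [n_classes] array of labels
    shared_data: list of [features, samples_i] arrays, one per subject
    """
    predictions = []
    for x_shared in shared_data:
        scores = theta.T.dot(x_shared) + bias        # [n_classes, samples_i]
        sumexp, _, exponents = sumexp_stable(scores)
        probs = exponents / sumexp[np.newaxis, :]    # softmax down each column
        predictions.append(classes[probs.argmax(axis=0)])
    return predictions


# Toy example: 2 shared features, 2 classes, one subject with 3 samples.
theta = np.array([[1.0, -1.0],
                  [0.0, 0.0]])
bias = np.zeros((2, 1))
classes = np.array(["face", "house"])
x = np.array([[2.0, -3.0, 0.5],
              [1.0, 1.0, 1.0]])

print(predict_labels(theta, bias, classes, [x]))
```

In the toy example the class scores are [2, -3, 0.5] and [-2, 3, -0.5], so the three samples are labeled face, house, face. Because exp is monotonic and each column is divided by the same positive sum, the argmax of the softmax probabilities always equals the argmax of the raw scores; the softmax only matters if calibrated probabilities are needed.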