classify.py source code

python

Project: Lyssandra  Author: ektormak
def __call__(self, X, y):
        """
        Split the dataset (X, y) for cross-validation, then run it.

        The split depends on which attribute is set:
        - if n_folds is not None, use stratified folds;
        - otherwise, if n_class_samples is not None, build n_tests splits
          with <n_class_samples> training samples per class; if
          n_test_samples is also not None, use only <n_test_samples>
          cross-validation samples per class.

        Assumes that each datapoint is a column of X.
        """
        n_classes = len(set(y))
        if self.n_folds is not None:
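            # generate the folds; passing y to the constructor matches the
            # legacy sklearn.cross_validation API (removed in scikit-learn 0.20)
            self.folds = StratifiedKFold(y, n_folds=self.n_folds,
                                         shuffle=False, random_state=None)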

        elif self.n_class_samples is not None:

            self.folds = []
            for i in range(self.n_tests):

                # broadcast a scalar count to one entry per class
                if not isinstance(self.n_class_samples, list):
                    self.n_class_samples = (np.ones(n_classes) * self.n_class_samples).astype(int)
                if self.n_test_samples is not None:
                    self.n_test_samples = (np.ones(n_classes) * self.n_test_samples).astype(int)

                data_idx = split_dataset(self.n_class_samples, self.n_test_samples, y)
                train_idx = data_idx[0]
                test_idx = data_idx[1]
                self.folds.append((train_idx, test_idx))

        self.cross_validate(X, y)
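
The per-class branch above relies on split_dataset, which is not shown in this file. The following is only a minimal sketch of what such a helper might do, written with the standard library alone. The name split_per_class, the seed parameter, and the choice to send all remaining samples to the test set when no test count is given are assumptions made for illustration, not the project's actual implementation.

```python
import random
from collections import defaultdict


def split_per_class(n_train_per_class, n_test_per_class, y, seed=None):
    """Hypothetical stand-in for split_dataset: sample indices class by class.

    n_train_per_class / n_test_per_class hold one count per class, ordered
    by sorted class label. If n_test_per_class is None, every sample not
    used for training goes to the test set (an assumption of this sketch).
    Returns (train_idx, test_idx) as sorted lists of positions in y.
    """
    rng = random.Random(seed)

    # Group sample positions by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(y):
        by_class[label].append(idx)

    train_idx, test_idx = [], []
    for k, label in enumerate(sorted(by_class)):
        members = by_class[label][:]
        rng.shuffle(members)  # random draw without replacement

        n_train = n_train_per_class[k]
        if n_train > len(members):
            raise ValueError(
                "class %r has only %d samples" % (label, len(members)))

        # Training and test indices come from disjoint slices.
        rest = members[n_train:]
        if n_test_per_class is not None:
            rest = rest[:n_test_per_class[k]]

        train_idx.extend(members[:n_train])
        test_idx.extend(rest)

    return sorted(train_idx), sorted(test_idx)


# Two classes: 2 training samples each, then 1 and 2 test samples.
y = [0, 0, 0, 0, 1, 1, 1, 1, 1]
train_idx, test_idx = split_per_class([2, 2], [1, 2], y, seed=0)
```

The returned pair has the same (train_idx, test_idx) shape that the loop appends to self.folds, so each call would correspond to one entry in that list.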