_split.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者: angadgill 项目源码 文件源码
def _make_test_folds(self, X, y=None, labels=None):
        """Assign each sample a test-fold index, stratified by class.

        A separate ``KFold`` split is generated for every distinct value in
        ``y``; the k-th test split of each class is labelled with fold index
        ``k``, so each fold receives a share of every class.

        Parameters
        ----------
        X : object
            Not used by this method.
        y : array-like
            Class label of each sample. Despite the ``None`` default a label
            array is required, because ``y.shape[0]`` is read to obtain the
            number of samples.
        labels : object, optional
            Not used by this method.

        Returns
        -------
        test_folds : numpy.ndarray of int, shape (n_samples,)
            ``test_folds[i]`` is the index of the fold whose test set
            contains sample ``i``.

        Raises
        ------
        ValueError
            If ``self.n_folds`` is greater than the number of members of
            every class.

        Warns
        -----
        Warning
            If ``self.n_folds`` is greater than the number of members of the
            least populated class.
        """
        # Only materialise a RandomState when shuffling; otherwise the
        # configured random_state is handed to KFold unchanged.
        if self.shuffle:
            rng = check_random_state(self.random_state)
        else:
            rng = self.random_state
        y = np.asarray(y)
        n_samples = y.shape[0]
        unique_y, y_inversed = np.unique(y, return_inverse=True)
        y_counts = bincount(y_inversed)
        min_labels = np.min(y_counts)
        if np.all(self.n_folds > y_counts):
            raise ValueError("All the n_labels for individual classes"
                             " are less than %d folds."
                             % (self.n_folds))
        if self.n_folds > min_labels:
            warnings.warn(("The least populated class in y has only %d"
                           " members, which is too few. The minimum"
                           " number of labels for any class cannot"
                           " be less than n_folds=%d."
                           % (min_labels, self.n_folds)), Warning)

        # pre-assign each sample to a test fold index using individual KFold
        # splitting strategies for each class so as to respect the balance of
        # classes
        # NOTE: Passing the data corresponding to ith class say X[y==class_i]
        # will break when the data is not 100% stratifiable for all classes.
        # So we pass np.zeros(max(c, n_folds)) as data to the KFold
        per_cls_cvs = [
            KFold(self.n_folds, shuffle=self.shuffle,
                  random_state=rng).split(np.zeros(max(count, self.n_folds)))
            for count in y_counts]

        # The builtin ``int`` is used as dtype: the ``np.int`` alias was
        # removed from NumPy and raises AttributeError on current releases.
        test_folds = np.zeros(n_samples, dtype=int)
        # zip(*per_cls_cvs) advances every per-class splitter in lockstep, so
        # each iteration yields the splits of all classes for one fold index.
        for test_fold_idx, per_cls_splits in enumerate(zip(*per_cls_cvs)):
            for cls, (_, test_split) in zip(unique_y, per_cls_splits):
                cls_mask = y == cls
                # Boolean-mask indexing returns a copy, hence the write-back
                # at the end of this iteration.
                cls_test_folds = test_folds[cls_mask]
                # the test split can be too big because we used
                # KFold(...).split(X[:max(c, n_folds)]) when data is not 100%
                # stratifiable for all the classes
                # (we use a warning instead of raising an exception)
                # If this is the case, let's trim it:
                test_split = test_split[test_split < len(cls_test_folds)]
                cls_test_folds[test_split] = test_fold_idx
                test_folds[cls_mask] = cls_test_folds

        return test_folds
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号