_split.py 文件源码

python
阅读 50 收藏 0 点赞 0 评论 0

项目:mriqc 作者: poldracklab 项目源码 文件源码
def split(self, X, y, groups=None):
        splits = super(BalancedKFold, self).split(X, y, groups)

        y = np.array(y)
        for train_index, test_index in splits:
            split_y = y[test_index]
            classes_y, y_inversed = np.unique(split_y, return_inverse=True)
            min_y = min(np.bincount(y_inversed))
            new_index = np.zeros(min_y * len(classes_y), dtype=int)

            for cls in classes_y:
                cls_index = test_index[split_y == cls]
                if len(cls_index) > min_y:
                    cls_index = np.random.choice(
                        cls_index, size=min_y, replace=False)

                new_index[cls * min_y:(cls + 1) * min_y] = cls_index
            yield train_index, new_index
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号