def split(self, X, y, groups=None):
    """Yield (train, test) index pairs whose test fold is class-balanced.

    Fold generation is delegated to the parent splitter. Each test fold is
    then undersampled so that every class present in it contributes the
    same number of samples: the count of the rarest class in that fold.
    Training indices are passed through unchanged.

    Parameters
    ----------
    X : array-like
        Forwarded unchanged to the parent ``split``.
    y : array-like
        Class label of every sample. Labels may be any values that
        ``np.unique`` can sort (ints, strings, ...); they do not need to
        be contiguous integers starting at 0.
    groups : array-like, optional
        Forwarded unchanged to the parent ``split``.

    Yields
    ------
    train_index : ndarray
        Training indices exactly as produced by the parent splitter.
    test_index : ndarray of int
        Balanced test indices, laid out class by class in ``np.unique``
        (sorted) label order, each class occupying an equal-length block.
    """
    splits = super(BalancedKFold, self).split(X, y, groups)
    y = np.asarray(y)
    for train_index, test_index in splits:
        split_y = y[test_index]
        # y_inversed maps each test sample to its class position 0..k-1.
        classes_y, y_inversed = np.unique(split_y, return_inverse=True)
        # Size of the rarest class in this fold; every class is cut to it.
        min_y = np.bincount(y_inversed).min()
        new_index = np.zeros(min_y * len(classes_y), dtype=int)
        # Place each class's block by its *position* among the unique
        # labels, not by the label value: label values are arbitrary
        # (strings, non-contiguous ints) and cannot be used as offsets.
        for pos in range(len(classes_y)):
            cls_index = test_index[y_inversed == pos]
            if len(cls_index) > min_y:
                # NOTE(review): draws from NumPy's global RNG, so any
                # random_state configured on the splitter is not used here;
                # seed np.random for reproducible folds.
                cls_index = np.random.choice(
                    cls_index, size=min_y, replace=False)
            new_index[pos * min_y:(pos + 1) * min_y] = cls_index
        yield train_index, new_index
评论列表
文章目录