def foldsplitter(taskcolumn, train_set_sizes):
    """Yield ``(train_mask, test_mask)`` pairs for per-task hold-out splits.

    For every distinct task id in ``taskcolumn`` (in sorted order), and for
    every ``train_set_size`` in ``train_set_sizes`` (in the given order),
    the first ``train_set_size`` rows of that task (in row order) go into
    the training set, together with every row of all other task ids.
    The remaining rows of that task form the test set.

    Parameters
    ----------
    taskcolumn : array-like
        One task id per row. Converted with ``np.asarray``, so lists and
        pandas Series (with any index) are handled positionally.
    train_set_sizes : iterable of numbers
        How many leading rows of the held-out task to move into training.
        It is materialised once, so one-shot iterators work for every task.

    Yields
    ------
    tuple of (numpy.ndarray, numpy.ndarray)
        Boolean masks ``(train, test)`` with one entry per row of
        ``taskcolumn``. ``test`` is always the complement of ``train``.
    """
    tasks = np.asarray(taskcolumn)
    # The sizes are re-iterated for every task id. A generator would be
    # exhausted after the first task, so materialise it up front.
    sizes = list(train_set_sizes)
    for fold in sorted(set(tasks)):
        in_task = tasks == fold
        # Running count of this task's rows. At a row of the task, this is
        # that row's 1-based position within the task.
        rank = np.cumsum(in_task)
        for train_set_size in sizes:
            # Train on every other task, plus this task's first
            # `train_set_size` rows. The rest of this task is held out.
            train = ~in_task | (rank <= train_set_size)
            yield (train, ~train)
# Scraped web-page residue, kept as a comment: "评论列表" ("comment list"),
# "文章目录" ("table of contents").