def _create_batches(self, n_batches=None):
    """Split every (CV fold, parameter setting) combination into work batches.

    Each pairing of a fold from ``self.cv_iter`` with a parameter setting
    from ``ParameterGrid(self.param_grid)`` becomes one work item
    ``(fold_id, train_index, test_index, parameters)``, where ``fold_id`` is
    1-based. Items are dealt round-robin across the batches in fold-major
    order (all parameter settings of fold 1, then fold 2, ...), so batch
    sizes differ by at most one.

    Parameters
    ----------
    n_batches : int, optional
        Number of batches to create. Defaults to the module-level
        ``comm_size`` (the communicator's size), which reproduces the
        original behaviour of one batch per communicator slot.

    Returns
    -------
    list of list of tuple
        Exactly ``n_batches`` lists of work items. A batch may be empty
        when there are fewer work items than batches.

    Raises
    ------
    ValueError
        If ``n_batches`` is smaller than 1.
    """
    import itertools

    if n_batches is None:
        n_batches = comm_size
    if n_batches < 1:
        raise ValueError(f"n_batches must be at least 1, got {n_batches}")

    param_iter = ParameterGrid(self.param_grid)
    # Divide the work into ``n_batches`` buckets (by default one per
    # communicator slot).
    work_batches = [[] for _ in range(n_batches)]

    # product() preserves the nested-loop order of the original code:
    # folds in the outer position, parameter settings in the inner one.
    # NOTE(review): self.cv_iter is consumed once here; if it is a one-shot
    # generator, a second call will produce only empty batches — confirm.
    combos = itertools.product(enumerate(self.cv_iter, start=1), param_iter)
    for i, (fold, parameters) in enumerate(combos):
        fold_id, (train_index, test_index) = fold
        work_batches[i % n_batches].append(
            (fold_id, train_index, test_index, parameters)
        )
    return work_batches
评论列表
文章目录