def multiprocessing_grid_search(queue, shared_list, persistent_object):
    """Explore a cross-validation grid as a worker consuming ``queue``.

    Repeatedly takes a work item from ``queue`` until it receives ``None``.
    Each work item is a 3-tuple ``(grid, cvs_args, cvs_kwargs)``:

    * ``grid`` -- dict of estimator parameters for this grid point; applied
      to the estimator with ``set_params(**grid)``.
    * ``cvs_args`` -- positional arguments for ``cross_val_score``: the
      estimator first, then the input data ``x``.  Any further positional
      arguments (e.g. targets ``y``) are forwarded unchanged.
    * ``cvs_kwargs`` -- keyword arguments forwarded to ``cross_val_score``.

    Scores are first looked up with ``persistent_object.retrieve(estimator,
    grid)``.  When that returns ``None`` they are computed with
    ``cross_val_score`` and stored via ``persistent_object.update(estimator,
    grid, scores)``.

    For every work item, a copy of ``grid`` with an added ``'scores'`` key is
    appended to ``shared_list``.

    Args:
        queue: object with a blocking ``get()``; ``None`` is the stop sentinel.
        shared_list: list-like object collecting one result dict per item.
        persistent_object: score cache exposing ``retrieve`` and ``update``.
    """
    while True:
        passed_parameters = queue.get()
        if passed_parameters is None:
            # Sentinel: no more work for this worker.
            break
        grid, cvs_args, cvs_kwargs = passed_parameters
        # Accept (estimator, x) as before, but also (estimator, x, y, ...):
        # extra positional arguments are passed through to cross_val_score.
        estimator, x, *extra_args = cvs_args
        # Configure the estimator in place before the cache lookup, so the
        # estimator handed to retrieve/update already carries this grid point.
        estimator.set_params(**grid)
        scores = persistent_object.retrieve(estimator, grid)
        if scores is None:
            # Cache miss: compute and persist the scores.
            scores = cross_val_score(estimator, x, *extra_args, **cvs_kwargs)
            persistent_object.update(estimator, grid, scores)
        # Copy so the 'scores' key is not added to the incoming grid dict.
        grid_result = grid.copy()
        grid_result['scores'] = scores
        shared_list.append(grid_result)
# Scraped web-page residue (not code):
# 评论列表 = "Comment list"; 文章目录 = "Table of contents".