def fit(self, X, y):
    """Run a serial grid search over ``self.parameter_grid``.

    For every parameter setting, ``mp_grid_search`` is called with the
    tuple ``(index, params, estimator, self.kf, X, y, self.score_fns,
    num_tasks)`` and its result is appended to ``self.grid_scores``.
    When ``self.checkpoint_path`` is set, the scores accumulated so far are
    pickled to ``<checkpoint_path>_<index>.p`` after each task, and any
    other files matching ``<checkpoint_path>*.p`` are then removed so only
    the latest checkpoint remains.

    Parameters
    ----------
    X : array-like
        Input data, validated by ``check_X_y`` (n-d input allowed).
    y : array-like
        Numeric target(s); multi-output targets are accepted.

    Returns
    -------
    self

    Raises
    ------
    NotImplementedError
        If ``self.njobs > 1``; parallel search is not implemented.
    """
    import glob
    import os

    X, y = check_X_y(X, y, allow_nd=True, multi_output=True,
                     y_numeric=True, estimator="GridSearch")
    print("njobs = {}".format(self.njobs))
    if self.njobs > 1:
        # This used to be `assert False`, which `python -O` strips, leaving
        # fit() to silently do nothing. Fail loudly instead.
        raise NotImplementedError(
            "Parallel grid search (njobs > 1) is not implemented; "
            "use njobs <= 1.")

    self.grid_scores = []
    # NOTE(review): a single estimator instance is shared by every parameter
    # setting; presumably mp_grid_search re-parameterizes it -- confirm.
    estimator = self.estimator_cls()
    num_tasks = len(self.parameter_grid)
    for i, params in enumerate(self.parameter_grid):
        print("Starting task {}/{}...".format(i + 1, num_tasks))
        task = (i, params, estimator, self.kf, X, y,
                self.score_fns, num_tasks)
        with stopwatch("Done. Elapsed time"):
            self.grid_scores.append(mp_grid_search(task))

        if self.checkpoint_path is not None:
            savepath = self.checkpoint_path + "_{}.p".format(i)
            # Write the new checkpoint before pruning old ones, so a crash
            # in between never leaves zero checkpoints on disk. Pickle data
            # must be written in binary mode.
            with open(savepath, 'wb') as f:
                pickle.dump(self.grid_scores, f)
            # Same glob the former `rm -f <checkpoint_path>*.p` used, but
            # without spawning a shell on a string-built command.
            keep = os.path.abspath(savepath)
            for stale in glob.glob(self.checkpoint_path + "*.p"):
                if os.path.abspath(stale) == keep:
                    continue
                try:
                    os.remove(stale)
                except OSError:
                    # Like `rm -f`: a file that already vanished is fine.
                    if os.path.exists(stale):
                        raise
    return self
# [web-page residue] Comment list
# [web-page residue] Table of contents