cross_validation.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:ottertune 作者: cmu-db 项目源码 文件源码
def fit(self, X, y):
    """Score every parameter setting in ``self.parameter_grid`` sequentially.

    ``X`` and ``y`` are first validated with ``check_X_y``. Each entry of
    ``self.parameter_grid`` is then handed to ``mp_grid_search`` and the
    returned value is appended to ``self.grid_scores``. When
    ``self.checkpoint_path`` is not ``None``, the accumulated
    ``self.grid_scores`` list is pickled to
    ``<checkpoint_path>_<task index>.p`` after every task and any other
    ``<checkpoint_path>*.p`` files are removed.

    Args:
        X: Input data, validated by ``check_X_y`` (``allow_nd=True``).
        y: Targets, validated by ``check_X_y`` (``multi_output=True``,
            ``y_numeric=True``).

    Raises:
        AssertionError: If ``self.njobs > 1``; only sequential execution
            is implemented.
        OSError: If a stale checkpoint exists but cannot be removed.
    """
    # Function-scope imports, matching how this method already pulled in its
    # helpers; they keep the checkpoint cleanup free of shell commands.
    import errno
    import glob
    import os

    X, y = check_X_y(X, y, allow_nd=True, multi_output=True,
                     y_numeric=True, estimator="GridSearch")
    print("njobs = {}".format(self.njobs))
    if self.njobs > 1:
        # Raised explicitly (rather than `assert False`) so the guard still
        # fires under `python -O`; the exception type is kept for callers.
        raise AssertionError(
            "parallel grid search (njobs > 1) is not supported; "
            "got njobs = {}".format(self.njobs))

    self.grid_scores = []
    # NOTE(review): a single estimator instance is shared by every task;
    # presumably mp_grid_search re-parameterizes it per call -- confirm.
    estimator = self.estimator_cls()
    num_tasks = len(self.parameter_grid)
    for i, params in enumerate(self.parameter_grid):
        print("Starting task {}/{}...".format(i + 1, num_tasks))
        with stopwatch("Done. Elapsed time"):
            task = (i, params, estimator, self.kf, X, y,
                    self.score_fns, num_tasks)
            self.grid_scores.append(mp_grid_search(task))

        if self.checkpoint_path is None:
            continue

        # Write the new checkpoint *before* deleting older ones so that a
        # failed write never leaves the run without any checkpoint.
        savepath = self.checkpoint_path + "_{}.p".format(i)
        with open(savepath, 'wb') as f:  # pickle streams are binary data
            pickle.dump(self.grid_scores, f)

        # glob may normalize the directory part of the path, so compare
        # absolute paths when deciding which file to keep.
        keep = os.path.abspath(savepath)
        for stale in glob.glob(self.checkpoint_path + "*.p"):
            if os.path.abspath(stale) == keep:
                continue
            try:
                os.remove(stale)
            except OSError as err:
                # Like `rm -f`: a file that is already gone is not an error.
                if err.errno != errno.ENOENT:
                    raise
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号