cv.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:pydl 作者: rafaeltg 项目源码 文件源码
def run(self, model, x, y=None, scoring=None, max_threads=1):
        """Cross-validate ``model`` over the splits of ``self.cv`` and consolidate the scores.

        Args:
            model: Model to evaluate. Must provide ``to_json()`` (its result is
                indexed with ``'model'``) and, when ``scoring`` is None,
                ``get_loss_func()``.
            x: Input samples. Passed to ``self.cv.split`` and to each per-fold call.
            y: Targets. When None, each fold is run through
                ``self._do_unsupervised_cv``; otherwise through
                ``self._do_supervised_cv`` with ``y`` included in the arguments.
            scoring: A single scorer identifier, or a list/tuple of identifiers.
                Each is resolved with ``get_scorer`` and keyed by
                ``self.get_scorer_name``. Defaults to the model's loss function.
            max_threads: Upper bound on the number of parallel jobs. The actual
                ``n_jobs`` is ``min(max_threads, number_of_folds)``.

        Returns:
            The result of ``self._consolidate_cv_scores`` applied to the list
            of per-fold results.

        Raises:
            ValueError: If the cross-validator produces no train/test splits.
        """
        # Build the {scorer_name: scorer_fn} mapping.
        if scoring is None:
            # By default, use the model's loss function as the scoring function.
            loss_name = model.get_loss_func()
            scorers_fn = {loss_name: get_scorer(loss_name)}
        else:
            scoring_list = scoring if isinstance(scoring, (list, tuple)) else [scoring]
            scorers_fn = {self.get_scorer_name(k): get_scorer(k) for k in scoring_list}

        model_spec = model.to_json()['model']

        folds = list(self.cv.split(x, y))
        if not folds:
            # Parallel(n_jobs=0) would fail with an opaque joblib error; report the real cause.
            raise ValueError('The cross-validator produced no train/test splits.')

        if y is None:
            cv_fn = self._do_unsupervised_cv
            args = [(model_spec, train, test, x, scorers_fn) for train, test in folds]
        else:
            cv_fn = self._do_supervised_cv
            args = [(model_spec, train, test, x, y, scorers_fn) for train, test in folds]

        with Parallel(n_jobs=min(max_threads, len(args))) as parallel:
            # Current joblib ``delayed`` accepts only the function. The old
            # ``check_pickle`` keyword was deprecated and is no longer supported.
            cv_results = parallel(delayed(cv_fn)(*a) for a in args)

        return self._consolidate_cv_scores(cv_results)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号