gridsearch.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:ycml 作者: skylander86 项目源码 文件源码
def fit_binarized(self, X_featurized, Y_binarized, validation_data=None, **kwargs):
        """Grid-search classifier hyperparameters, then fit the best model.

        Each parameter combination from ``ParameterGrid(self.param_grid)`` is
        handed to ``_fit_classifier`` (together with the classifier class,
        ``self.classifier_args``, ``self.metric`` and the train/validation
        splits), dispatched through ``Parallel(n_jobs=self.n_jobs)``.  The
        combination with the highest returned score is merged over
        ``self.classifier_args`` to construct ``self.classifier_``, which is
        then fitted on ``X_featurized`` / ``Y_binarized``.

        Args:
            X_featurized: Training features.
            Y_binarized: Binarized training labels aligned with ``X_featurized``.
            validation_data: Optional ``(X_validation, Y_validation)`` pair.
                When ``None``, a validation split is carved out of the training
                data with ``train_test_split(test_size=self.validation_size)``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``self``, with ``self.classifier_`` set to the fitted final model.

        Raises:
            ValueError: If ``self.param_grid`` yields no parameter combinations.
        """
        klass = get_class_from_module_path(self.classifier)

        if validation_data is None:
            # Hold out part of the training data; the split size is
            # self.validation_size (passed as train_test_split's test_size).
            X_train, X_validation, Y_train, Y_validation = train_test_split(X_featurized, Y_binarized, test_size=self.validation_size)
            logger.info('Using {} of training data ({} instances) for validation.'.format(self.validation_size, Y_validation.shape[0]))
        else:
            X_train, Y_train = X_featurized, Y_binarized
            X_validation, Y_validation = validation_data[0], validation_data[1]
        #end if

        # joblib treats n_jobs=None as a single job and negative values (e.g. -1)
        # as "use all/most CPUs", so any value other than None/1 runs in
        # parallel. Comparing with `> 1` would misclassify -1 and raise
        # TypeError for None on Python 3.
        if self.n_jobs not in (None, 1): logger.info('Performing hyperparameter gridsearch in parallel using {} jobs.'.format(self.n_jobs))
        else: logger.debug('Performing hyperparameter gridsearch sequentially.')

        param_scores = Parallel(n_jobs=self.n_jobs)(delayed(_fit_classifier)(klass, self.classifier_args, param, self.metric, X_train, Y_train, X_validation, Y_validation) for param in ParameterGrid(self.param_grid))
        if not param_scores:
            raise ValueError('param_grid {!r} produced no parameter combinations to search.'.format(self.param_grid))

        # Each result is a (param, score) pair; max() keeps the first of any
        # tied scores, i.e. the earliest grid point in iteration order.
        best_param, best_score = max(param_scores, key=lambda x: x[1])
        logger.info('Best scoring param is {} with score {}.'.format(best_param, best_score))

        # Best grid params override the base classifier arguments.
        classifier_args = dict(self.classifier_args)
        classifier_args.update(best_param)
        self.classifier_ = klass(**classifier_args)
        logger.info('Fitting final model <{}> on full data with param {}.'.format(self.classifier_, best_param))
        self.classifier_.fit(X_featurized, Y_binarized)

        return self
    #end def
#end class
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号