pyglmnet.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:pyglmnet 作者: glm-tools 项目源码 文件源码
def _set_cv(cv, estimator=None, X=None, y=None):
    """Build a cross-validation object and materialise its splits.

    Parameters
    ----------
    cv : int | str | object
        If an integer (Python ``int`` or NumPy integer scalar), the number
        of folds; a ``StratifiedKFold`` is used for classifiers and a
        ``KFold`` otherwise. If a string, the name of a cross-validation
        class looked up in scikit-learn's CV module. Anything else is
        handed to scikit-learn's ``check_cv`` unchanged.
    estimator : 'classifier' | 'regressor' | object | None
        Either one of the two literal strings, or an object that is passed
        to ``is_classifier`` to decide between the two.
    X : array-like | None
        Only forwarded to ``check_cv`` on scikit-learn < 0.18.
    y : array-like | None
        Targets. Forwarded to ``check_cv`` and to the splitter; on
        scikit-learn < 0.18 ``len(y)`` is also used to size the folds.

    Returns
    -------
    cv : object
        The cross-validation object returned by ``check_cv``.
    cv_splits : list of tuple
        The ``(train, test)`` pairs produced by ``cv``.

    Raises
    ------
    ValueError
        If ``cv`` is a string that does not name an attribute of the
        scikit-learn CV module, or if any fold has an empty training set.
    NotImplementedError
        On scikit-learn < 0.18, if ``cv`` is a string naming anything other
        than ``KFold`` or ``LeaveOneOut``.
    """
    # Detect whether classification or regression
    if estimator in ['classifier', 'regressor']:
        est_is_classifier = estimator == 'classifier'
    else:
        est_is_classifier = is_classifier(estimator)

    # ``np.int`` was a plain alias of the builtin ``int`` and has been
    # removed from NumPy (1.24+); ``np.integer`` is the abstract base of
    # every NumPy integer scalar, so this also accepts e.g. ``np.int64``.
    integer_types = (int, np.integer)

    # Setup CV
    if check_version('sklearn', '0.18'):
        from sklearn import model_selection as models
        from sklearn.model_selection import (check_cv,
                                             StratifiedKFold, KFold)
        if isinstance(cv, integer_types):
            XFold = StratifiedKFold if est_is_classifier else KFold
            cv = XFold(n_splits=cv)
        elif isinstance(cv, str):
            if not hasattr(models, cv):
                raise ValueError('Unknown cross-validation')
            # Instantiate the named CV class with its default arguments.
            cv = getattr(models, cv)
            cv = cv()
        cv = check_cv(cv=cv, y=y, classifier=est_is_classifier)
    else:
        # Legacy API: the splitters need the data size / labels up front.
        from sklearn import cross_validation as models
        from sklearn.cross_validation import (check_cv,
                                              StratifiedKFold, KFold)
        if isinstance(cv, integer_types):
            if est_is_classifier:
                cv = StratifiedKFold(y=y, n_folds=cv)
            else:
                cv = KFold(n=len(y), n_folds=cv)
        elif isinstance(cv, str):
            if not hasattr(models, cv):
                raise ValueError('Unknown cross-validation')
            cv = getattr(models, cv)
            # Only these two classes are constructed here from the sample
            # count alone.
            if cv.__name__ not in ['KFold', 'LeaveOneOut']:
                raise NotImplementedError('CV cannot be defined with str'
                                          ' for sklearn < .017.')
            cv = cv(len(y))
        cv = check_cv(cv=cv, X=X, y=y, classifier=est_is_classifier)

    # Extract train and test set to retrieve them at predict time
    if hasattr(cv, 'split'):
        # A placeholder shaped like ``y`` is passed as X instead of the
        # real data.
        cv_splits = [(train, test) for train, test in
                     cv.split(X=np.zeros_like(y), y=y)]
    else:
        # XXX support sklearn.cross_validation cv
        cv_splits = [(train, test) for train, test in cv]

    if not np.all([len(train) for train, _ in cv_splits]):
        raise ValueError('Some folds do not have any train epochs.')

    return cv, cv_splits
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号