randomized_search.py source code

python

Project: healthcareai-py  Author: HealthCatalyst
from sklearn.model_selection import RandomizedSearchCV


def get_algorithm(estimator,
                  scoring_metric,
                  hyperparameter_grid,
                  randomized_search,
                  number_iteration_samples=10,
                  **non_randomized_estimator_kwargs):
    """
    Given an estimator and various params, initialize an algorithm with optional randomized search.

    Args:
        estimator (type): a scikit-learn estimator class, not an instance (for example: KNeighborsClassifier). It is
            instantiated inside this function.
        scoring_metric (str): The scoring metric to optimize for if using randomized search. See
            http://scikit-learn.org/stable/modules/model_evaluation.html
        hyperparameter_grid (dict): Maps each hyperparameter name to the values (or distribution) to sample from
            during the search.
        randomized_search (bool): Whether the method should return a randomized search estimator (as opposed to a
            simple algorithm).
        number_iteration_samples (int): If performing randomized search, the number of hyperparameter settings
            sampled from the search space. Higher values take longer to run but make it more likely that a
            near-optimal setting is found.
        **non_randomized_estimator_kwargs: Keyword arguments passed directly to the estimator's constructor. Only
            used when randomized_search is False.

    Returns:
        sklearn.base.BaseEstimator: an unfitted scikit-learn algorithm ready to `.fit()`

    """
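    if randomized_search:
        # Wrap a default-constructed estimator in a randomized hyperparameter search.
        # Note that non_randomized_estimator_kwargs are ignored in this branch.
        algorithm = RandomizedSearchCV(estimator=estimator(),
                                       scoring=scoring_metric,
                                       param_distributions=hyperparameter_grid,
                                       n_iter=number_iteration_samples,
                                       verbose=0,
                                       n_jobs=1)

    else:
        # No search: build the estimator directly from the supplied keyword arguments.
        algorithm = estimator(**non_randomized_estimator_kwargs)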

    return algorithm
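
A minimal usage sketch of both branches, not taken from the healthcareai-py project. The synthetic dataset, the `knn_grid` search space, and the choice of `'accuracy'` as the metric are assumptions made for illustration; `get_algorithm` is repeated in condensed form only so the example runs on its own.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import RandomizedSearchCV
from sklearn.neighbors import KNeighborsClassifier


def get_algorithm(estimator, scoring_metric, hyperparameter_grid, randomized_search,
                  number_iteration_samples=10, **non_randomized_estimator_kwargs):
    # Condensed copy of the function above, repeated so this sketch is self-contained.
    if randomized_search:
        return RandomizedSearchCV(estimator=estimator(),
                                  scoring=scoring_metric,
                                  param_distributions=hyperparameter_grid,
                                  n_iter=number_iteration_samples,
                                  verbose=0,
                                  n_jobs=1)
    return estimator(**non_randomized_estimator_kwargs)


# Synthetic classification data (assumption: any small labelled dataset would do).
X, y = make_classification(n_samples=200, n_features=8, random_state=0)

# Hypothetical search space: 15 x 2 = 30 candidate settings.
knn_grid = {'n_neighbors': list(range(1, 16)),
            'weights': ['uniform', 'distance']}

# Randomized search: pass the estimator class itself; 5 settings are sampled.
searched = get_algorithm(KNeighborsClassifier, 'accuracy', knn_grid,
                         randomized_search=True, number_iteration_samples=5)
searched.fit(X, y)
best_params = searched.best_params_

# Plain estimator: extra keyword arguments go to the constructor.
plain = get_algorithm(KNeighborsClassifier, 'accuracy', knn_grid,
                      randomized_search=False, n_neighbors=5)
plain.fit(X, y)
```

With `randomized_search=True` the returned object is a `RandomizedSearchCV`, so the chosen settings are read from `best_params_` after fitting. With `randomized_search=False` the grid and scoring metric are unused and only the keyword arguments shape the estimator.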