def get_algorithm(estimator,
                  scoring_metric,
                  hyperparameter_grid,
                  randomized_search,
                  number_iteration_samples=10,
                  **non_randomized_estimator_kwargs):
    """
    Build an algorithm from an estimator factory, optionally wrapped in a randomized hyperparameter search.

    Args:
        estimator (callable): A callable that produces the estimator when invoked, typically a scikit-learn
            estimator class (for example: KNeighborsClassifier). It is called with no arguments when randomized
            search is requested, and with ``non_randomized_estimator_kwargs`` otherwise.
        scoring_metric (str): The scoring metric to optimize for when using randomized search; it is forwarded
            as the ``scoring`` argument. See http://scikit-learn.org/stable/modules/model_evaluation.html
        hyperparameter_grid (dict): Key/value pairs describing the hyperparameter space to search through;
            forwarded as ``param_distributions``. Only used when randomized_search is truthy.
        randomized_search (bool): Whether to return a randomized search estimator (as opposed to a simple
            algorithm).
        number_iteration_samples (int): Number of samples drawn from the hyperparameter space during randomized
            search (forwarded as ``n_iter``). Larger values take longer but make it more likely that a
            well-performing hyperparameter combination is found.
        **non_randomized_estimator_kwargs: Keyword arguments passed straight to ``estimator``. Only used when
            randomized_search is falsy; they are ignored otherwise.

    Returns:
        The object to call ``.fit()`` on: either a ``RandomizedSearchCV`` wrapping a default-constructed
        estimator, or the estimator built directly from the supplied keyword arguments.
    """
    if not randomized_search:
        # Simple case: construct the estimator directly with the caller's keyword arguments.
        return estimator(**non_randomized_estimator_kwargs)

    # Search case: the estimator is created with its defaults and the search explores the grid,
    # running quietly (verbose=0) in a single job (n_jobs=1).
    return RandomizedSearchCV(
        estimator=estimator(),
        scoring=scoring_metric,
        param_distributions=hyperparameter_grid,
        n_iter=number_iteration_samples,
        verbose=0,
        n_jobs=1,
    )
randomized_search.py 文件源码
python
阅读 69
收藏 0
点赞 0
评论 0
评论列表
文章目录