def check_parameters_default_constructible(name, Estimator):
    """Check that ``Estimator`` is default-constructible and that its
    ``__init__`` does nothing but store parameters.

    The estimator is instantiated without arguments.  Names listed in
    ``META_ESTIMATORS`` are the exception: they receive a
    ``LinearDiscriminantAnalysis`` instance as their single positional
    argument.  The instance is then cloned and ``repr``'d, and
    ``set_params()`` must return the estimator itself.

    Finally, every hyper-parameter of ``__init__`` is checked.  These are
    all parameters except ``self``, ``*args``, ``**kwargs``, and the first
    parameter of a meta-estimator.  Each one must:

    * have a default value,
    * have a default whose exact type is one of an allowed set of simple
      types, and
    * have a default equal to the value reported by ``get_params()``.

    The last condition matters for tools such as grid search, which
    rebuild estimators from their parameters.

    Parameters
    ----------
    name : str
        Name of the estimator; compared against ``META_ESTIMATORS``.
    Estimator : type
        Estimator class under test.

    Returns
    -------
    None
        Returns early, without checking parameters, when no signature can
        be obtained for ``__init__`` (e.g. for mixins).  Failures are
        reported through the ``assert_*`` helpers.
    """
    classifier = LinearDiscriminantAnalysis()
    # test default-constructibility
    # get rid of deprecation warnings
    # NOTE(review): ``record=True`` diverts warnings into a list that is
    # discarded; a warnings filter set to "error" would still raise here.
    with warnings.catch_warnings(record=True):
        if name in META_ESTIMATORS:
            # meta-estimators can need a non-default argument
            estimator = Estimator(classifier)
        else:
            estimator = Estimator()
        # test cloning
        clone(estimator)
        # test __repr__
        repr(estimator)
        # test that set_params returns self
        assert_true(estimator.set_params() is estimator)

        # test if init does nothing but set parameters
        # this is important for grid_search etc.
        # We get the default parameters from init and then
        # compare these against the actual values of the attributes.

        # Unwrap the deprecation decorator (if any) via its
        # 'deprecated_original' attribute so the original __init__
        # signature is inspected.
        init = getattr(estimator.__init__, 'deprecated_original',
                       estimator.__init__)

        def param_filter(p):
            """Identify hyper parameters of an estimator"""
            # self, *args and **kwargs are not hyper-parameters.
            return (p.name != 'self'
                    and p.kind != p.VAR_KEYWORD
                    and p.kind != p.VAR_POSITIONAL)

        # Only signature() can raise here: TypeError for unsupported
        # objects, ValueError when no signature can be provided.
        try:
            init_signature = signature(init)
        except (TypeError, ValueError):
            # init is not a python function.
            # true for mixins
            return
        init_params = [p for p in init_signature.parameters.values()
                       if param_filter(p)]

        params = estimator.get_params()
        if name in META_ESTIMATORS:
            # they can need a non-default argument: skip the first one
            init_params = init_params[1:]

        for init_param in init_params:
            # ``empty`` is a sentinel class: compare by identity so that a
            # default with elementwise ``!=`` (e.g. an ndarray) cannot raise
            # an ambiguous-truth ValueError instead of a clean failure.
            assert_true(init_param.default is not init_param.empty,
                        "parameter %s for %s has no default value"
                        % (init_param.name, type(estimator).__name__))
            # Exact type match (``type(...)`` looked up in a list), so
            # subclasses of these types are only accepted if listed.
            assert_in(type(init_param.default),
                      [str, int, float, bool, tuple, type(None),
                       np.float64, types.FunctionType, Memory])
            if init_param.name not in params:
                # deprecated parameter, not in get_params
                assert_true(init_param.default is None)
                continue
            param_value = params[init_param.name]
            # ndarray equality is elementwise, so arrays need their own
            # comparison helper.
            if isinstance(param_value, np.ndarray):
                assert_array_equal(param_value, init_param.default)
            else:
                assert_equal(param_value, init_param.default)
评论列表
文章目录