def cluster_meanshift(X_train, model_args=None, gridsearch=True, estimate_bandwidth_samples=None):
    """Build a ``ModelWrapper`` around scikit-learn's ``MeanShift`` clusterer.

    Args:
        X_train: Training data. It is passed to ``estimate_bandwidth`` (when
            no bandwidth is supplied) and to ``ModelWrapper`` as ``X``.
        model_args: Optional mapping of keyword arguments for ``MeanShift``.
            ``None`` is treated as an empty mapping. The mapping is copied,
            so the caller's object is never modified.
        gridsearch: Must not be ``True``; hyperparameter search is not
            implemented for this model because it has no scoring method.
            Note that the default is ``True``, so callers have to pass
            ``gridsearch=False`` explicitly.
        estimate_bandwidth_samples: Forwarded as ``n_samples`` to
            ``sklearn.cluster.estimate_bandwidth`` when the bandwidth has to
            be estimated.

    Returns:
        Whatever ``ModelWrapper(MeanShift, X=X_train, model_args=...,
        param_grid=None, unsupervised=True)`` returns. ``ModelWrapper`` is
        defined elsewhere in the project.

    Raises:
        NotImplementedError: If ``gridsearch is True``.
    """
    print('MeanShift')
    if gridsearch is True:
        # TODO: add hyperparameter searching. No scoring method is available
        # for this model, so we can't easily use grid searching. Once a
        # param_grid exists it should be pruned against model_args here.
        raise NotImplementedError('No hyperparameter optimization available yet for this model. Set gridsearch to False')

    # Imported after the fail-fast check so that the unsupported grid-search
    # path reports its real error without needing scikit-learn.
    from sklearn.cluster import MeanShift, estimate_bandwidth

    # Work on a private copy: this tolerates the default ``None`` and avoids
    # writing the estimated bandwidth into the caller's dictionary.
    model_args = {} if model_args is None else dict(model_args)

    if 'bandwidth' not in model_args:
        print('Calculating the bandwidth')
        model_args['bandwidth'] = estimate_bandwidth(X_train, n_samples=estimate_bandwidth_samples)

    return ModelWrapper(MeanShift, X=X_train, model_args=model_args, param_grid=None, unsupervised=True)
# (Web-page residue from the original paste: "comment list" / "table of contents".)