def cluster_kmeans(X_train, model_args=None, gridsearch=True):
    """Build a ``ModelWrapper`` around scikit-learn's ``KMeans``.

    Args:
        X_train: Training data, passed to ``ModelWrapper`` as ``X``.
        model_args: Optional mapping of fixed ``KMeans`` keyword arguments,
            forwarded unchanged to ``ModelWrapper``. When ``gridsearch`` is
            false it must be provided and must contain ``'n_clusters'``.
        gridsearch: When truthy, a hyper-parameter grid over ``n_clusters``,
            ``max_iter`` and ``tol`` is built and handed to ``ModelWrapper``.
            Otherwise no grid is used (``param_grid=None``).

    Returns:
        The ``ModelWrapper`` instance wrapping ``KMeans`` with
        ``unsupervised=True``.

    Raises:
        KeyError: If ``gridsearch`` is false and ``model_args`` is missing
            or has no ``'n_clusters'`` entry.
    """
    from sklearn.cluster import KMeans

    print('KMeans')
    if gridsearch:
        param_grid = {
            'n_clusters': np.arange(1, 20, 2),  # 1, 3, 5, ..., 19
            'max_iter': [50, 100, 300],
            'tol': [1e-5, 1e-4, 1e-3],
        }
        # NOTE(review): prune() is defined elsewhere and its return value is
        # ignored, so it presumably edits param_grid in place (e.g. dropping
        # keys already fixed in model_args) -- confirm against its definition.
        prune(param_grid, model_args)
    else:
        # `model_args or {}` so that the default model_args=None raises the
        # documented KeyError rather than a TypeError from `in None`.
        if 'n_clusters' not in (model_args or {}):
            raise KeyError('Need to define n_clusters for KMeans')
        param_grid = None
    return ModelWrapper(KMeans, X=X_train, model_args=model_args,
                        param_grid=param_grid, unsupervised=True)
评论列表
文章目录