def cluster_agglomerative(X_train, model_args=None, gridsearch=True, connectivity_graph=True, connectivity_graph_neighbors=10):
    """Wrap scikit-learn's ``AgglomerativeClustering`` in a ``ModelWrapper``.

    Args:
        X_train: Training samples. Passed to ``kneighbors_graph`` (when a
            connectivity graph is requested) and to ``ModelWrapper`` as ``X``.
        model_args: Mapping of keyword arguments for
            ``AgglomerativeClustering``. Must contain ``'n_clusters'``.
            ``None`` is treated as an empty mapping. The mapping is copied,
            so the caller's object is never modified.
        gridsearch: Hyperparameter search is not implemented for this model;
            passing ``True`` (the default) raises ``NotImplementedError``.
        connectivity_graph: When truthy, a k-neighbors graph is computed from
            ``X_train`` and stored under ``model_args['connectivity']``.
        connectivity_graph_neighbors: ``n_neighbors`` used for that graph.

    Returns:
        The ``ModelWrapper`` constructed for ``AgglomerativeClustering`` with
        ``param_grid=None`` and ``unsupervised=True``.

    Raises:
        NotImplementedError: If ``gridsearch is True``.
        KeyError: If ``'n_clusters'`` is missing from ``model_args``.
    """
    print('AgglomerativeClustering')

    # Fail fast: reject unsupported/invalid requests before importing
    # scikit-learn or paying for the k-neighbors graph computation.
    if gridsearch is True:
        # TODO: add hyperparameter searching. No scoring method is available
        # for this model, so grid searching cannot easily be used (a pruned
        # param_grid would be built here once it is).
        raise NotImplementedError('No hyperparameter optimization available yet for this model. Set gridsearch to False')

    # Work on a private copy: avoids a TypeError when model_args is None and
    # keeps the connectivity matrix from leaking into the caller's dict.
    model_args = {} if model_args is None else dict(model_args)
    if 'n_clusters' not in model_args:
        raise KeyError('Need to define n_clusters for AgglomerativeClustering')

    from sklearn.cluster import AgglomerativeClustering
    from sklearn.neighbors import kneighbors_graph

    if connectivity_graph:
        print('Creating k-neighbors graph for connectivity restraint')
        model_args['connectivity'] = kneighbors_graph(X_train, n_neighbors=connectivity_graph_neighbors)

    return ModelWrapper(AgglomerativeClustering, X=X_train, model_args=model_args, param_grid=None, unsupervised=True)
评论列表
文章目录