def cluster_gaussian(X_train, model_args=None, gridsearch=True):
    """Wrap scikit-learn's ``GaussianMixture`` in a ``ModelWrapper``.

    Parameters
    ----------
    X_train
        Training data, forwarded unchanged to ``ModelWrapper`` as ``X``.
    model_args
        Estimator arguments, forwarded to ``ModelWrapper`` and also passed
        to ``prune`` together with the search grid.
    gridsearch
        When truthy, a hyperparameter grid over ``n_components`` and
        ``covariance_type`` is passed to ``ModelWrapper``; otherwise
        ``param_grid`` is ``None``.

    Returns
    -------
    The result of ``ModelWrapper(GaussianMixture, ..., unsupervised=True)``.
    """
    # Imported locally so sklearn is loaded only when this function is called.
    from sklearn.mixture import GaussianMixture

    print('GaussianMixture')

    # Test truthiness rather than identity with the True singleton (PEP 8):
    # with `is True`, values such as numpy.bool_(True) or 1 silently skipped
    # the grid search.
    if gridsearch:
        param_grid = {
            # Evaluates 1, 5, 9, 13 and 17 mixture components.
            'n_components': np.arange(1, 20, 4),
            'covariance_type': ['full', 'tied', 'diag', 'spherical'],
        }
        # NOTE(review): prune's return value is ignored, so it presumably
        # modifies param_grid in place (e.g. dropping entries fixed by
        # model_args) -- confirm against prune's definition.
        prune(param_grid, model_args)
    else:
        param_grid = None

    return ModelWrapper(
        GaussianMixture,
        X=X_train,
        model_args=model_args,
        param_grid=param_grid,
        unsupervised=True,
    )
评论列表
文章目录