Clustering.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:GRIPy 作者: giruenf 项目源码 文件源码
def expectation_maximization(data, nc, cv_type='full', req_info=None):
    """Cluster ``data`` with a Gaussian mixture model fitted by EM.

    Parameters
    ----------
    data : array-like
        Samples to cluster; passed unchanged to ``GMM.fit`` and
        ``GMM.predict`` (presumably shaped (n_samples, n_features) --
        confirm against callers).
    nc : int
        Number of mixture components.
    cv_type : str
        Covariance type forwarded to ``GMM``; this function distinguishes
        'full' and 'tied' and treats every other value (e.g. 'diag',
        'spherical') as one variance vector per component.
    req_info : None, 'all' or container of str
        Extra diagnostics to compute. Recognised keys: 'aic', 'bic',
        'converged', 'weights', 'means', 'covars', 'silhouette', 'proba'.
        'all' selects every key; None selects none.

    Returns
    -------
    labels : ndarray
        Component index predicted for each sample by ``gmm.predict``.
    info : dict
        Requested diagnostics keyed by name. 'covars' is normalised to one
        (n_features, n_features) matrix per component for 'tied' and
        diagonal types; for 'full' it is ``gmm.covars_`` as-is. 'proba' is
        the transpose of ``gmm.predict_proba(data)``.

    Raises
    ------
    ValueError
        From ``metrics.silhouette_score`` when 'silhouette' is requested
        and the predicted labels do not form a valid clustering (e.g. all
        samples fall into a single component).
    """
    # NOTE(review): GMM/thresh look like the legacy sklearn.mixture.GMM API,
    # which was removed in scikit-learn 0.20 -- confirm the pinned version.
    gmm = GMM(n_components=nc, covariance_type=cv_type, thresh=1.0E-4, n_init=10)
    gmm.fit(data)

    labels = gmm.predict(data)

    if req_info == 'all':
        req_info = ['aic', 'bic', 'converged', 'weights', 'means', 'covars',
                    'silhouette', 'proba']
    elif req_info is None:
        req_info = []

    info = {}
    if 'aic' in req_info:
        info['aic'] = gmm.aic(data)
    if 'bic' in req_info:
        info['bic'] = gmm.bic(data)
    if 'converged' in req_info:
        info['converged'] = gmm.converged_
    if 'weights' in req_info:
        info['weights'] = gmm.weights_
    if 'means' in req_info:
        info['means'] = gmm.means_
    if 'covars' in req_info:
        if cv_type == 'full':
            # Already one full covariance matrix per component.
            info['covars'] = gmm.covars_
        elif cv_type == 'tied':
            # A single (n_features, n_features) matrix shared by every
            # component: replicate it so each component gets its own copy.
            cov = np.empty((nc, gmm.covars_.shape[0], gmm.covars_.shape[1]))
            for i in range(nc):
                cov[i] = gmm.covars_.copy()
            info['covars'] = cov
        else:
            # Diagonal-style types store one variance vector per component,
            # i.e. covars_ has shape (nc, n_features). Each vector expands
            # to an (n_features, n_features) diagonal matrix, so the output
            # must be (nc, n_features, n_features) -- NOT (nc, nc, n_features),
            # which breaks whenever nc != n_features.
            n_features = gmm.covars_.shape[1]
            cov = np.empty((nc, n_features, n_features))
            for i in range(nc):
                cov[i] = np.diag(gmm.covars_[i])
            info['covars'] = cov
    if 'silhouette' in req_info:
        info['silhouette'] = metrics.silhouette_score(data, labels)
    if 'proba' in req_info:
        # Transposed so that each row holds one component's probabilities.
        info['proba'] = gmm.predict_proba(data).T

    return labels, info
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号