clustering.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:PyFusionGUI 作者: SyntaxVoid 项目源码 文件源码
def EM_GMM_clustering(instance_array, n_clusters=9, sin_cos = 0, number_of_starts = 10, show_covariances = 0, clim=None, covariance_type='diag'):
    """Cluster instances with scikit-learn's EM Gaussian mixture model (``mixture.GMM``).

    Parameters
    ----------
    instance_array : ndarray
        2-D array indexed as ``[instance, feature]`` (the sin/cos branch uses
        ``shape[0]`` and ``shape[1]``). Presumably the features are phases in
        radians -- TODO confirm against callers.
    n_clusters : int
        Number of mixture components, passed as ``n_components``.
    sin_cos : int or bool
        If truthy, each feature column is replaced by its cosine and sine,
        interleaved as (cos, sin, cos, sin, ...), which doubles the feature
        count. The keys of the returned dictionary then carry an ``_sc``
        suffix.
    number_of_starts : int
        Passed to ``mixture.GMM`` as ``n_init``.
    show_covariances : int or bool
        If truthy, plot the absolute value of every component's covariance
        matrix in a grid of subplots.
    clim : sequence of two floats or None
        Colour limits applied to every covariance plot. If None, each plot
        uses ``[0, 0.5 * <its automatic upper limit>]``.
    covariance_type : str
        Passed through to ``mixture.GMM``.

    Returns
    -------
    cluster_assignments : ndarray
        Output of ``gmm.predict`` on the fitted data.
    cluster_details : dict
        ``'EM_GMM_means[_sc]'`` (``gmm.means_``),
        ``'EM_GMM_variances[_sc]'`` (diagonal of each component covariance),
        ``'EM_GMM_covariances[_sc]'`` (stacked component covariances),
        ``'BIC'`` (``gmm.bic``) and ``'LL'`` (sum of ``gmm.score``).
    """
    # One truthiness test drives both the feature transform and the key
    # suffix. Otherwise e.g. sin_cos=2 could fit raw phases while labelling
    # the results as sin/cos.
    use_sin_cos = bool(sin_cos)
    print('starting EM-GMM algorithm from sckit-learn, k=%d, retries : %d, sin_cos = %d'%(n_clusters,number_of_starts,sin_cos))
    if use_sin_cos:
        print('  using sine and cosine of the phases')
        n_instances, n_features = instance_array.shape[0], instance_array.shape[1]
        input_data = np.zeros((n_instances, n_features * 2), dtype=float)
        input_data[:, ::2] = np.cos(instance_array)
        input_data[:, 1::2] = np.sin(instance_array)
    else:
        print('  using raw phases')
        input_data = instance_array

    # NOTE(review): mixture.GMM is the legacy scikit-learn API (deprecated in
    # 0.18, removed in 0.20 in favour of GaussianMixture). It is kept here so
    # the interface is unchanged.
    gmm = mixture.GMM(n_components=n_clusters, covariance_type=covariance_type, n_init=number_of_starts)
    gmm.fit(input_data)
    cluster_assignments = gmm.predict(input_data)
    bic_value = gmm.bic(input_data)
    # Summed to a single number. This assumes gmm.score returns per-sample
    # log-likelihoods -- TODO confirm for the scikit-learn version in use.
    LL = np.sum(gmm.score(input_data))

    # Fetch the component covariance matrices once. The same array is used
    # for plotting, for the 'covariances' output and for the variances.
    gmm_covars_tmp = np.array(gmm._get_covars())

    if show_covariances:
        n_components = gmm_covars_tmp.shape[0]
        fig, ax = make_grid_subplots(n_components, sharex = True, sharey = True)
        for i in range(n_components):
            image = ax[i].imshow(np.abs(gmm_covars_tmp[i, :, :]), aspect='auto')
            print(image.get_clim())
            # 'is None' rather than '== None': clim may be a numpy array, for
            # which '==' compares elementwise and the 'if' would raise.
            if clim is None:
                # Default: zero lower limit and half the automatic upper limit.
                image.set_clim([0, image.get_clim()[1] * 0.5])
            else:
                image.set_clim(clim)
        fig.subplots_adjust(hspace=0, wspace=0, left=0.05, bottom=0.05, top=0.95, right=0.95)
        fig.canvas.draw()
        fig.show()

    # Per-component variances: the diagonal of each covariance matrix.
    gmm_covars = np.array([np.diagonal(cov) for cov in gmm_covars_tmp])

    suffix = '_sc' if use_sin_cos else ''
    cluster_details = {
        'EM_GMM_means' + suffix: gmm.means_,
        'EM_GMM_variances' + suffix: gmm_covars,
        'EM_GMM_covariances' + suffix: gmm_covars_tmp,
        'BIC': bic_value,
        'LL': LL,
    }
    return cluster_assignments, cluster_details
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号