Simple_function.py source code

python

Project: vapor  Author: mills-lab
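import numpy as np
from scipy.cluster.vq import kmeans, vq, whiten
from sklearn import cluster

# compute_bic(model, X) is not shown in this snippet. It is assumed to be defined
# elsewhere in Simple_function.py and to return a score where larger is better,
# since the code below keeps the maximum.

def k_means_cluster(data_list):
    """Split paired coordinates into 1-4 clusters, choosing the count by BIC.

    data_list is [xs, ys]: two equal-length sequences of numbers.
    Returns a list of [xs, ys] groups, one per cluster.
    """
    # Only try to split when both coordinates span more than 10 units.
    if max(data_list[0]) - min(data_list[0]) > 10 and max(data_list[1]) - min(data_list[1]) > 10:
        # Pair the coordinates into an (n, 2) array of points.
        array_diagnal = np.array([[data_list[0][x], data_list[1][x]] for x in range(len(data_list[0]))])
        # Candidate cluster counts: 1 up to min(4, number of points).
        ks = list(range(1, min(5, len(data_list[0]) + 1)))
        # Fit each model once and reuse its labels_, so the labels checked below come
        # from the same model that is scored (a separate fit_predict call would
        # reinitialise the centroids and could disagree).
        KMeans = [cluster.KMeans(n_clusters=i, init="k-means++").fit(array_diagnal) for i in ks]
        BIC = []
        BIC_rec = []
        for x in ks:
            # Skip k when the highest label is never used (fewer distinct clusters than requested).
            if KMeans[x - 1].labels_.max() < x - 1:
                continue
            BIC_i = compute_bic(KMeans[x - 1], array_diagnal)
            # Keep only finite, reasonably sized scores (NaN and inf fail this test).
            if abs(BIC_i) < 10**8:
                BIC.append(BIC_i)
                BIC_rec.append(x)
        # Fall back to a single group if no candidate produced a usable score.
        if not BIC:
            return [data_list]
        ks_picked = BIC_rec[BIC.index(max(BIC))]
        if ks_picked == 1:
            return [data_list]
        # Re-cluster with SciPy on whitened data (each column scaled to unit variance).
        whitened = whiten(array_diagnal)
        centroids, distortion = kmeans(whitened, ks_picked)
        idx, _ = vq(whitened, centroids)
        out = []
        # scipy.cluster.vq.kmeans can return fewer centroids than requested,
        # so iterate over the centroids it actually returned.
        for x in range(len(centroids)):
            group1 = [[int(i) for i in array_diagnal[idx == x, 0]],
                      [int(i) for i in array_diagnal[idx == x, 1]]]
            out.append(group1)
        return out
    else:
        return [data_list]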
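The function relies on compute_bic, which this page does not show. The sketch below is a hypothetical stand-in, not the vapor project's implementation. It scores a fitted k-means model with the Bayesian information criterion, assuming every cluster is a spherical Gaussian sharing one variance (the X-means style formulation), and returns a value where larger is better, matching the snippet's use of max(BIC). It only needs the cluster_centers_ and labels_ attributes that sklearn.cluster.KMeans exposes; the toy models in the usage lines are fakes built with SimpleNamespace.

```python
from types import SimpleNamespace

import numpy as np


def compute_bic(model, X):
    """Hypothetical BIC score for a fitted k-means model (larger is better).

    Assumes `model` has `cluster_centers_` (k x d) and `labels_` (length N),
    as sklearn.cluster.KMeans does, and that each cluster is a spherical
    Gaussian with one shared variance.
    """
    X = np.asarray(X, dtype=float)
    centers = np.asarray(model.cluster_centers_, dtype=float)
    labels = np.asarray(model.labels_)
    k = centers.shape[0]
    n_points, dim = X.shape
    counts = np.bincount(labels, minlength=k)

    # Pooled per-dimension variance around the assigned centroids,
    # using N - k degrees of freedom.
    sq_dist = ((X - centers[labels]) ** 2).sum()
    variance = sq_dist / ((n_points - k) * dim)

    # Log-likelihood: mixing-weight term plus Gaussian term per cluster.
    # The (n_i - 1) * dim / 2 terms sum to (N - k) * dim / 2, which equals
    # sum ||x - mu||^2 / (2 * variance) for the variance above.
    log_lik = 0.0
    for n_i in counts:
        if n_i == 0:
            continue  # empty cluster contributes nothing (avoids log(0))
        log_lik += (n_i * np.log(n_i) - n_i * np.log(n_points)
                    - n_i * dim / 2 * np.log(2 * np.pi * variance)
                    - (n_i - 1) * dim / 2)

    # k * dim centroid coordinates + (k - 1) weights + 1 variance = k * (dim + 1).
    penalty = 0.5 * k * (dim + 1) * np.log(n_points)
    return log_lik - penalty


# Usage: two well-separated pairs of points, scored with k = 1 and k = 2.
X = np.array([[0, 0], [0, 1], [10, 0], [10, 1]], dtype=float)
one = SimpleNamespace(cluster_centers_=np.array([[5.0, 0.5]]),
                      labels_=np.array([0, 0, 0, 0]))
two = SimpleNamespace(cluster_centers_=np.array([[0.0, 0.5], [10.0, 0.5]]),
                      labels_=np.array([0, 0, 1, 1]))
bic1 = compute_bic(one, X)
bic2 = compute_bic(two, X)
print(bic1 < bic2)  # prints True: the two-cluster model scores higher
```

Degenerate fits give non-finite scores: with k equal to the number of points the variance denominator is zero, and if every point sits exactly on its centroid the variance is zero. The snippet's abs(BIC_i) < 10**8 check discards such values, presumably for this reason.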