Classification.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:UVA 作者: chiachun 项目源码 文件源码
def run(self):
        """Cluster the rows of ``self.dfp`` and keep only well-populated classes.

        Pipeline (labels are written to the ``label1`` column of ``self.dfp``):

        1. KMeans with about ``len(self.dfp) // cfg.avg_clsize`` clusters.
        2. MeanShift seeded with those KMeans centers. Its bandwidth is
           estimated from the largest KMeans cluster. The number of distinct
           MeanShift labels becomes the new cluster count.
        3. KMeans again with that cluster count.
        4. ``mean_shift`` on every resulting class.
        5. Drop classes whose sample count is not above ``self.min_counts``.
        6. Merge the surviving classes via ``self.mean_shift_merge``.

        ``self.dfp`` is modified: the ``label1`` column is (re)assigned, and
        step 5 rebinds it to the filtered frame. Returns ``None``.
        """
        cfg = self.cfg
        # presumably populates self.dfp / self.pccols -- TODO confirm
        self.prepare()

        # Step 1: coarse KMeans. Its centers seed the MeanShift run below, which
        # can split a KMeans cluster into sub-clusters if it finds sub-groups.
        # Floor division keeps the Python 2 integer result under Python 3
        # (KMeans needs an int cluster count). Requesting at least one cluster
        # prevents n_clusters == 0 when the dataset is smaller than
        # cfg.avg_clsize.
        n_clusters = max(1, int(len(self.dfp) // cfg.avg_clsize))
        labels, centers = self.run_kmean(n_clusters)
        self.dfp['label1'] = labels

        # Step 2: estimate the MeanShift bandwidth from the largest KMeans group.
        largest = self.dfp['label1'].value_counts().idxmax()
        in_largest = self.dfp['label1'] == largest
        xp_largest = self.dfp.loc[in_largest, self.pccols].values
        bandwidth = estimate_bandwidth(xp_largest, quantile=0.3)

        # Seed MeanShift with the KMeans centers. cluster_all=True assigns every
        # sample to a cluster, so the number of distinct labels is used as the
        # cluster count for the second KMeans.
        ms = MeanShift(bandwidth=bandwidth, seeds=centers, cluster_all=True)
        ms.fit(self.dfp[self.pccols].values)
        nc = len(np.unique(ms.labels_))

        # Step 3: KMeans again, using the number of clusters MeanShift found.
        labels, centers = self.run_kmean(nc)
        self.dfp['label1'] = labels
        kvals = np.unique(self.dfp['label1'].values)
        # Single-argument print: same output on Python 2 and Python 3
        # (the two spaces match the old `print "...: ", kvals` statement).
        print("Classes after the second Kmean:  %s" % (kvals,))

        # Step 4: run mean_shift on each KMeans class. Per the original design,
        # samples classified as other clusters are assigned new labels.
        # NOTE(review): mean_shift is defined elsewhere and its return values
        # were never used here. Confirm it updates self.dfp['label1'] in place.
        for kval in kvals:
            mean_shift(self.dfp, kval, kval, 'label1', 0.3, True, False)

        # Step 5: count cut. Keep only classes with more than self.min_counts
        # samples. The counts are computed once and reused for the report.
        cnts = self.dfp['label1'].value_counts()
        print("Classification result before merging")
        print("class  counts")
        print(cnts)
        passed_cnts = cnts[cnts > self.min_counts].index.tolist()
        self.dfp = self.dfp[self.dfp['label1'].isin(passed_cnts)]

        # Step 6: merge the surviving classes.
        # NOTE(review): every other step uses the 'label1' column. Confirm that
        # 'label' (not 'label1') is the intended argument here.
        self.mean_shift_merge('label')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号