clustering.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:PyFusionGUI 作者: SyntaxVoid 项目源码 文件源码
def __init__(self, instance_array_amps, n_clusters = 9, n_iterations = 20, n_cpus=1, start='random', kappa_calc='approx', hard_assignments = 0, kappa_converged = 0.1, mu_converged = 0.01, min_iterations=10, LL_converged = 1.e-4, verbose = 0, seed=None, norm_method = 'sum'):
        """Cluster complex amplitude data with an EM Gaussian mixture fit.

        Each row of ``instance_array_amps`` is normalised between channels,
        its real and imaginary parts are stacked side by side into
        ``self.input_data``, and the EM loop alternates
        ``_EM_VMM_GMM_expectation_step`` / ``_EM_VMM_GMM_maximisation_step``
        until the convergence criteria or the iteration cap are met. All of
        the work happens here; results are left on ``self`` (notably
        ``cluster_assignments``, ``BIC``, ``LL_list`` and ``cluster_details``).

        Parameters
        ----------
        instance_array_amps : ndarray
            2-D array indexed as [instance, channel] (the code indexes
            ``axis=1`` and ``shape[1]``). Presumably complex amplitudes,
            since both real and imaginary parts are used -- confirm.
        n_clusters : int
            Number of mixture components.
        n_iterations : int
            Iteration cap. The counter starts at 1 and the loop stops once it
            exceeds this value, so up to ``n_iterations + 1`` EM passes run.
        n_cpus, kappa_calc : object
            Only recorded in ``self.settings`` here.
        start : str
            Recorded in ``self.settings`` and ``self.start``; presumably
            consumed by ``_initialisation`` -- verify.
        hard_assignments : bool-like
            If true, the responsibilities ``zij`` are replaced by one-hot
            argmax assignments after every expectation step.
        kappa_converged, mu_converged, LL_converged : float
            Thresholds on ``convergence_std``, ``convergence_mean`` and the
            relative log-likelihood change. All three must be satisfied, and
            the iteration counter must exceed ``min_iterations``, to declare
            convergence.
        min_iterations : int
            Convergence is not tested until the counter exceeds this.
        verbose : bool-like
            Print a progress line after every iteration.
        seed : int or None
            Passed to ``np.random.seed``; defaults to ``os.getpid()``.
        norm_method : str
            Between-channel normalisation. ``'sum'`` divides each row by its
            sum; any other value is delegated to ``norm_bet_chans``.
        """
        print('EM_GMM_GMM2 %s' % (instance_array_amps.shape,))
        self.settings = {'n_clusters':n_clusters,'n_iterations':n_iterations,'n_cpus':n_cpus,'start':start,
                         'kappa_calc':kappa_calc,'hard_assignments':hard_assignments, 'method':'EM_VMM_GMM'}
        self.instance_array_amps = instance_array_amps
        # The default 'sum' normalisation is computed inline (each row divided
        # by its sum); every other method is handled by norm_bet_chans so that
        # the norm_method argument actually takes effect.
        if norm_method == 'sum':
            self.data_complex = instance_array_amps / np.sum(instance_array_amps, axis=1)[:, np.newaxis]
        else:
            self.data_complex = norm_bet_chans(instance_array_amps, method=norm_method)
        # Real and imaginary parts side by side: n_dimensions == 2 * n_dim.
        self.input_data = np.hstack((np.real(self.data_complex), np.imag(self.data_complex)))
        self.n_dim = self.data_complex.shape[1]
        self.n_instances, self.n_dimensions = self.input_data.shape

        self.n_clusters = n_clusters
        self.max_iterations = n_iterations
        self.start = start
        self.hard_assignments = hard_assignments
        # Presumably the pid default gives parallel worker processes distinct
        # seeds -- confirm with callers.
        self.seed = os.getpid() if seed is None else seed
        print('seed, %s' % (self.seed,))
        np.random.seed(self.seed)
        self.iteration = 1

        self._initialisation()
        self.convergence_record = []
        self.LL_diff = np.inf
        progress_fmt = ('Time for iteration %d :%.2f, mu_convergence:%.3e, '
                        'kappa_convergence:%.3e, LL: %.8e, LL_dif : %.3e')
        converged = False
        while not converged:
            start_time = time.time()
            self._EM_VMM_GMM_expectation_step()
            if self.hard_assignments:
                print('hard assignments')
                # Collapse soft responsibilities to one-hot argmax assignments.
                self.cluster_assignments = np.argmax(self.zij, axis=1)
                self.zij = self.zij * 0
                for i in range(self.n_clusters):
                    self.zij[self.cluster_assignments == i, i] = 1

            # Skip vanishing probabilities so np.log never sees zero.
            valid_items = self.probs > 1.e-300
            self.LL_list.append(np.sum(self.zij[valid_items] * np.log(self.probs[valid_items])))
            self._EM_VMM_GMM_maximisation_step()
            if self.iteration >= 2:
                # Relative change of the log-likelihood between the last two passes.
                self.LL_diff = np.abs((self.LL_list[-1] - self.LL_list[-2]) / self.LL_list[-2])
            if verbose:
                print(progress_fmt % (self.iteration, time.time() - start_time,
                                      self.convergence_mean, self.convergence_std,
                                      self.LL_list[-1], self.LL_diff))
            self.convergence_record.append([self.iteration, self.convergence_mean, self.convergence_std])
            if (self.iteration > min_iterations
                    and self.convergence_mean < mu_converged
                    and self.convergence_std < kappa_converged
                    and self.LL_diff < LL_converged):
                converged = True
                print('Convergence criteria met!!')
            elif self.iteration > n_iterations:
                converged = True
                print('Max number of iterations')
            self.iteration += 1

        # self.iteration has already been advanced past the last completed pass.
        last_iteration = self.iteration - 1
        print('%s %s' % (os.getpid(),
                         progress_fmt % (last_iteration, time.time() - start_time,
                                         self.convergence_mean, self.convergence_std,
                                         self.LL_list[-1], self.LL_diff)))
        self.cluster_assignments = np.argmax(self.zij, axis=1)
        # NOTE(review): the BIC penalty is conventionally
        # n_params * log(number of observations), i.e. log(self.n_instances);
        # this uses log(self.n_dimensions). Left unchanged pending confirmation.
        self.BIC = -2 * self.LL_list[-1] + self.n_clusters * 3 * np.log(self.n_dimensions)
        # mean_list / std_list columns are [real parts | imaginary parts].
        gmm_means_re, gmm_means_im = np.hsplit(self.mean_list, 2)
        gmm_vars_re, gmm_vars_im = np.hsplit(self.std_list**2, 2)

        self.cluster_details = {'EM_GMM_means':self.mean_list, 'EM_GMM_variances':self.std_list**2,'BIC':self.BIC,'LL':self.LL_list,
                                'EM_GMM_means_re':gmm_means_re, 'EM_GMM_variances_re':gmm_vars_re,
                                'EM_GMM_means_im':gmm_means_im, 'EM_GMM_variances_im':gmm_vars_im}
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号