subcrpmm.py source code

python

Project: PyBGMM  Author: junlulocky
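import numpy as np
from numpy import linalg as LA

# Method excerpted from a PyBGMM class. NIW, GaussianComponents,
# GaussianComponentsDiag and GaussianComponentsFixedVar are defined elsewhere
# in PyBGMM; their import paths are not shown in this snippet.


def update_clustering_components(self, mask, assignments):
    # Keep only the features (columns of self.X) whose mask entry equals 1.
    cluster_idx = 1
    feature_idx = np.where(mask == cluster_idx)[0]
    cluster_D = feature_idx.shape[0]    # number of selected features
    cluster_X = self.X[:, feature_idx]  # shape (N, cluster_D)

    # Data-dependent scales; they are only referenced by the commented-out
    # alternatives for k_0 and S_0 below.
    if cluster_D == 1:
        covar_scale = np.var(cluster_X)
    else:
        # The covariance matrix is symmetric, so eigvalsh returns real eigenvalues.
        covar_scale = np.median(LA.eigvalsh(np.cov(cluster_X.T)))
    mu_scale = np.amax(cluster_X) - covar_scale

    # Initialize the normal-inverse-Wishart (NIW) prior
    m_0 = cluster_X.mean(axis=0)  # prior mean: per-feature empirical mean
    k_0 = 1.0 / self.h1           # scaling of the prior on the mean (kappa_0)
    # k_0 = covar_scale ** 2 / mu_scale ** 2
    v_0 = cluster_D + 2           # degrees of freedom (must be > cluster_D - 1)
    # S_0 = 1./100 / covar_scale * np.eye(cluster_D)
    S_0 = 1. * np.eye(cluster_D)  # prior scale matrix: identity

    cluster_kernel_prior = NIW(m_0, k_0, v_0, S_0)

    # Build the mixture components for the requested covariance structure.
    if self.covariance_type == "full":
        components = GaussianComponents(cluster_X, cluster_kernel_prior, assignments, self.K_max)
    elif self.covariance_type == "diag":
        components = GaussianComponentsDiag(cluster_X, cluster_kernel_prior, assignments, self.K_max)
    elif self.covariance_type == "fixed":
        components = GaussianComponentsFixedVar(cluster_X, cluster_kernel_prior, assignments, self.K_max)
    else:
        raise ValueError("Invalid covariance type: %r" % (self.covariance_type,))

    return components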
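The prior built in this method is a normal-inverse-Wishart (NIW) distribution over a component's mean μ and covariance Σ. Under the standard parametrization, Σ ~ IW(S_0, v_0) and μ | Σ ~ N(m_0, Σ / k_0). The minimal sketch below makes the four hyperparameters concrete. It assumes that standard parametrization, and PyBGMM's own NIW class may differ. It mirrors the prior construction in plain NumPy and then draws one (μ, Σ) pair from the resulting prior. NIWParams, niw_prior_from_mask and sample_niw are hypothetical helpers written for illustration; they are not part of PyBGMM.

```python
from typing import NamedTuple

import numpy as np


class NIWParams(NamedTuple):
    """Hypothetical container for normal-inverse-Wishart hyperparameters."""
    m_0: np.ndarray  # prior mean vector, shape (D,)
    k_0: float       # scaling of the mean's covariance (kappa_0)
    v_0: int         # inverse-Wishart degrees of freedom (nu_0)
    S_0: np.ndarray  # inverse-Wishart scale matrix, shape (D, D)


def niw_prior_from_mask(X, mask, h1, cluster_idx=1):
    """Hypothetical helper mirroring the prior set-up in update_clustering_components.

    X is an (N, F) data matrix; mask is a length-F array whose entries equal to
    cluster_idx mark the features used for clustering.
    """
    feature_idx = np.flatnonzero(mask == cluster_idx)
    cluster_X = X[:, feature_idx]  # (N, D) with D = number of selected features
    D = feature_idx.size
    return NIWParams(
        m_0=cluster_X.mean(axis=0),
        k_0=1.0 / h1,
        v_0=D + 2,
        S_0=np.eye(D),
    )


def sample_niw(params, rng):
    """Draw (mu, Sigma) with Sigma ~ IW(S_0, v_0) and mu | Sigma ~ N(m_0, Sigma / k_0).

    Uses the integer-dof Wishart construction, so v_0 must be an integer >= D.
    """
    D = params.m_0.shape[0]
    # Summing v_0 outer products of N(0, S_0^{-1}) vectors gives a
    # Wishart(S_0^{-1}, v_0) precision matrix; its inverse is IW(S_0, v_0).
    Z = rng.multivariate_normal(np.zeros(D), np.linalg.inv(params.S_0), size=params.v_0)
    Sigma = np.linalg.inv(Z.T @ Z)
    Sigma = (Sigma + Sigma.T) / 2  # remove round-off asymmetry
    mu = rng.multivariate_normal(params.m_0, Sigma / params.k_0)
    return mu, Sigma


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 5))
    mask = np.array([1, 0, 1, 1, 0])  # features 0, 2 and 3 are selected
    prior = niw_prior_from_mask(X, mask, h1=0.5)
    mu, Sigma = sample_niw(prior, rng)
    print(prior.v_0, mu.shape, Sigma.shape)  # 5 (3,) (3, 3)
```

With v_0 = D + 2, the prior expected covariance S_0 / (v_0 - D - 1) equals S_0. A larger k_0 (a smaller h1) concentrates μ more tightly around the empirical mean m_0.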