subcrpmm.py source code

python

Project: PyBGMM  Author: junlulocky
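import numpy as np
from numpy import linalg as LA  # assumption: the excerpt does not show which linalg module LA aliases

# NIW, GaussianComponents, GaussianComponentsDiag and GaussianComponentsFixedVar
# are defined elsewhere in PyBGMM; their imports are not part of this excerpt.


def update_common_component(self, mask):
    """Build the single Gaussian component shared by the dimensions where mask == 0."""
    common_idx = 0
    common_cols = np.where(mask == common_idx)[0]
    if common_cols.shape[0] == 0:
        # No dimension is marked as common: force the last one (mutates mask in place).
        mask[-1] = 0
        common_cols = np.where(mask == common_idx)[0]
    common_D = common_cols.shape[0]
    common_X = self.X[:, common_cols]

    # Rough data scales. The active code below does not use them; they appear
    # only in the commented-out alternatives for k_0 and S_0.
    if common_D == 1:
        # np.cov returns a 0-d array for a single dimension, so use the variance directly.
        covar_scale = np.var(common_X)
    else:
        covar_scale = np.median(LA.eigvals(np.cov(common_X.T)))
    mu_scale = np.amax(common_X) - covar_scale

    # Normal-inverse-Wishart prior hyperparameters.
    m_0 = common_X.mean(axis=0)  # prior mean: per-dimension sample mean
    k_0 = 1.0 / self.h0          # prior strength on the mean
    # k_0 = covar_scale**2/mu_scale**2
    v_0 = common_D + 2           # degrees of freedom
    # S_0 = 1. / covar_scale * np.eye(common_D)
    S_0 = 1. * np.eye(common_D)  # identity scale matrix
    common_kernel_prior = NIW(m_0, k_0, v_0, S_0)

    # All data points are assigned to the one common component (index 0).
    common_assignments = np.zeros(common_X.shape[0])

    if self.common_component_covariance_type == "full":
        common_component = GaussianComponents(common_X, common_kernel_prior, common_assignments, 1)
    elif self.common_component_covariance_type == "diag":
        common_component = GaussianComponentsDiag(common_X, common_kernel_prior, common_assignments, 1)
    elif self.common_component_covariance_type == "fixed":
        common_component = GaussianComponentsFixedVar(common_X, common_kernel_prior, common_assignments, 1)
    else:
        # Raise explicitly: an assert would be skipped when Python runs with -O.
        raise ValueError("Invalid covariance type.")

    return common_component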
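
The column selection and prior hyperparameters computed in the method above can be tried out without the rest of PyBGMM. The following is a minimal sketch using NumPy only. The helper name `common_prior_hyperparams` and its arguments are hypothetical and not part of PyBGMM, and `h0` stands in for `self.h0`. Unlike the method, the sketch copies `mask` rather than modifying it in place.

```python
import numpy as np


def common_prior_hyperparams(X, mask, h0, common_idx=0):
    """Return (m_0, k_0, v_0, S_0) for the columns where mask == common_idx.

    Hypothetical helper: it mirrors the arithmetic of update_common_component
    but returns plain values instead of building PyBGMM objects.
    """
    mask = np.array(mask)  # copy, so the caller's mask is left untouched
    cols = np.where(mask == common_idx)[0]
    if cols.size == 0:
        # Same fallback as the method: force the last dimension to be common.
        mask[-1] = common_idx
        cols = np.where(mask == common_idx)[0]
    common_X = X[:, cols]
    common_D = cols.size

    m_0 = common_X.mean(axis=0)  # prior mean: per-dimension sample mean
    k_0 = 1.0 / h0               # prior strength on the mean
    v_0 = common_D + 2           # degrees of freedom
    S_0 = np.eye(common_D)       # identity scale matrix
    return m_0, k_0, v_0, S_0


# Usage: 50 samples, 4 dimensions, of which dimensions 1 and 3 are "common".
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 4))
mask = np.array([1, 0, 1, 0])
m_0, k_0, v_0, S_0 = common_prior_hyperparams(X, mask, h0=10.0)
print(m_0.shape, k_0, v_0, S_0.shape)
```

With two common dimensions the prior mean has two entries and `S_0` is a 2x2 identity matrix. If no entry of `mask` equals 0, the fallback selects the last column only.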