def update_common_component(self, mask):
    """Build the Gaussian component that models the "common" dimensions.

    The common dimensions are the columns of ``self.X`` whose label in
    ``mask`` equals 0 (the dimensions not used by any other component).
    Every row of ``self.X`` is assigned to one single component, whose
    NIW prior is centred on the empirical mean of those columns.

    Reads ``self.X``, ``self.h0`` and
    ``self.common_component_covariance_type``.

    Parameters
    ----------
    mask : numpy.ndarray
        Per-dimension labels; entries equal to 0 mark common dimensions.
        Presumably 1-D with one entry per column of ``self.X`` -- confirm
        against callers. **Mutated in place**: if no entry is 0, the last
        entry is set to 0 so the common component has at least one
        dimension.

    Returns
    -------
    GaussianComponents, GaussianComponentsDiag or GaussianComponentsFixedVar
        Selected by ``self.common_component_covariance_type``
        (``"full"``, ``"diag"`` or ``"fixed"`` respectively).

    Raises
    ------
    ValueError
        If ``self.common_component_covariance_type`` is not one of
        ``"full"``, ``"diag"`` or ``"fixed"``.
    """
    common_idx = 0
    common_dims = np.where(mask == common_idx)[0]
    if common_dims.shape[0] == 0:
        # Guarantee a non-empty common component by reassigning the last
        # dimension to it. This deliberately modifies the caller's mask.
        mask[-1] = common_idx
        common_dims = np.where(mask == common_idx)[0]
    common_D = common_dims.shape[0]
    common_X = self.X[:, common_dims]

    # Prior hyperparameters, presumably (mean, kappa_0, nu_0, scale matrix)
    # in NIW's argument order -- confirm against NIW. The mean is the
    # empirical column mean, kappa_0 = 1 / h0, nu_0 = D + 2 and the scale
    # matrix is the D x D identity.
    m_0 = common_X.mean(axis=0)
    k_0 = 1.0 / self.h0
    v_0 = common_D + 2
    S_0 = np.eye(common_D)
    common_kernel_prior = NIW(m_0, k_0, v_0, S_0)

    # All rows belong to the single common component (label 0, count 1).
    common_assignments = np.zeros(common_X.shape[0])
    covariance_type = self.common_component_covariance_type
    if covariance_type == "full":
        return GaussianComponents(
            common_X, common_kernel_prior, common_assignments, 1
        )
    if covariance_type == "diag":
        return GaussianComponentsDiag(
            common_X, common_kernel_prior, common_assignments, 1
        )
    if covariance_type == "fixed":
        return GaussianComponentsFixedVar(
            common_X, common_kernel_prior, common_assignments, 1
        )
    # Explicit raise instead of `assert False`, which is stripped under -O.
    raise ValueError(f"Invalid covariance type: {covariance_type!r}.")
评论列表
文章目录