def train_gmm(X, max_iter, tol, means, covariances):
    """Fit a two-component mixture model to ``X`` by alternating E and M steps.

    Each iteration calls ``e_step`` with the current parameters, then
    ``m_step`` with the exponentiated log-responsibilities, and stops as
    soon as the absolute change of the value returned as ``log_prob_norm``
    drops below ``tol``.

    Args:
        X: Input data array (NumPy or CuPy); the array module is picked
            from it via ``cupy.get_array_module``.
        max_iter: Maximum number of E/M iterations to run.
        tol: Convergence threshold on the absolute change of the lower
            bound between two consecutive iterations.
        means: Initial component means, forwarded to ``e_step``.
        covariances: Initial covariances. Only ``1 / sqrt(covariances)``
            is used here, computed element-wise.
            NOTE(review): this looks like per-dimension variances
            (diagonal covariance) -- confirm against ``e_step``/``m_step``.

    Returns:
        Tuple ``(inv_cov, means, weights, covariances)`` holding the
        parameters after the last completed iteration. If ``max_iter`` is
        0 the initial values are returned unchanged.

    Note:
        Non-convergence is not an error: a message is printed and the
        latest parameters are still returned.
    """
    xp = cupy.get_array_module(X)

    # ``np.inf`` rather than the ``np.infty`` alias: the alias was removed
    # in NumPy 2.0, while ``np.inf`` is available in every NumPy version.
    lower_bound = -np.inf
    converged = False

    # Mixing weights start uniform over exactly two components.
    weights = xp.array([0.5, 0.5], dtype=np.float32)
    inv_cov = 1 / xp.sqrt(covariances)

    for _ in six.moves.range(max_iter):
        prev_lower_bound = lower_bound

        log_prob_norm, log_resp = e_step(X, inv_cov, means, weights)
        weights, means, covariances = m_step(X, xp.exp(log_resp))
        # Keep the cached inverse square root in sync with the new estimate.
        inv_cov = 1 / xp.sqrt(covariances)

        lower_bound = log_prob_norm
        # On the first pass the previous bound is -inf, so a finite bound
        # gives an infinite change and the loop cannot stop immediately.
        change = lower_bound - prev_lower_bound
        if abs(change) < tol:
            converged = True
            break

    if not converged:
        print('Failed to converge. Increase max-iter or tol.')

    return inv_cov, means, weights, covariances
评论列表
文章目录