gmm.py source code

python

Project: cupy  Author: cupy
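import cupy
import numpy as np

# e_step and m_step are defined elsewhere in gmm.py.


def train_gmm(X, max_iter, tol, means, covariances):
    # Pick numpy or cupy depending on where X lives (CPU or GPU).
    xp = cupy.get_array_module(X)
    lower_bound = -np.inf  # np.infty was removed in NumPy 2.0
    converged = False
    # Two components with equal initial mixing weights.
    weights = xp.array([0.5, 0.5], dtype=np.float32)
    # Inverse standard deviations, computed element-wise from the covariances.
    inv_cov = 1 / xp.sqrt(covariances)

    for n_iter in range(max_iter):
        prev_lower_bound = lower_bound
        # E-step: log-likelihood bound and log responsibilities.
        log_prob_norm, log_resp = e_step(X, inv_cov, means, weights)
        # M-step: re-estimate the parameters from the responsibilities.
        weights, means, covariances = m_step(X, xp.exp(log_resp))
        inv_cov = 1 / xp.sqrt(covariances)
        lower_bound = log_prob_norm
        # Stop once the bound changes by less than tol.
        change = lower_bound - prev_lower_bound
        if abs(change) < tol:
            converged = True
            break

    if not converged:
        print('Failed to converge. Increase max-iter or tol.')

    return inv_cov, means, weights, covariances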
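
train_gmm calls e_step and m_step, which this listing does not show. The block below is a minimal sketch of what such helpers could look like. It is not the cupy project's code. It assumes diagonal covariances (one variance per component and feature), uses NumPy only, and adds a hypothetical `reg` variance floor that is not in the original. The loop at the end mirrors the body of train_gmm on synthetic two-cluster data.

```python
import numpy as np


def e_step(X, inv_cov, means, weights):
    """Hypothetical E-step for a diagonal-covariance GMM.

    inv_cov holds 1 / sqrt(variance) per component and feature,
    matching how train_gmm builds it.
    """
    n_features = X.shape[1]
    precisions = inv_cov ** 2
    # log of |Sigma_k|^(-1/2) for a diagonal covariance
    log_det = np.sum(np.log(inv_cov), axis=1)
    # Squared Mahalanobis distance, expanded into matrix products
    maha = (np.sum(means ** 2 * precisions, axis=1)
            - 2.0 * X @ (means * precisions).T
            + (X ** 2) @ precisions.T)
    log_prob = -0.5 * (n_features * np.log(2.0 * np.pi) + maha) + log_det
    weighted = log_prob + np.log(weights)
    # log-sum-exp over components, stabilised by subtracting the row max
    row_max = weighted.max(axis=1, keepdims=True)
    log_norm = row_max[:, 0] + np.log(np.exp(weighted - row_max).sum(axis=1))
    log_resp = weighted - log_norm[:, None]
    return log_norm.mean(), log_resp


def m_step(X, resp, reg=1e-6):
    """Hypothetical M-step: update weights, means and diagonal variances."""
    # Effective number of points per component (tiny offset avoids 0-division)
    nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps
    means = (resp.T @ X) / nk[:, None]
    # E[x^2] - mean^2 per component, plus a small floor for stability
    covariances = (resp.T @ (X ** 2)) / nk[:, None] - means ** 2 + reg
    weights = nk / X.shape[0]
    return weights, means, covariances


# Synthetic data: two well-separated 2-D clusters (illustrative only).
rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(-3.0, 1.0, size=(500, 2)),
                    rng.normal(3.0, 1.0, size=(500, 2))])
means = np.array([[-1.0, -1.0], [1.0, 1.0]])
covariances = np.ones((2, 2))
weights = np.array([0.5, 0.5])

# Same update order as the loop in train_gmm, with a fixed iteration count.
for _ in range(20):
    inv_cov = 1 / np.sqrt(covariances)
    lower_bound, log_resp = e_step(X, inv_cov, means, weights)
    weights, means, covariances = m_step(X, np.exp(log_resp))

print(means)
print(weights)
```

With helpers of this shape, `covariances` is a `(n_components, n_features)` array of variances, which is consistent with train_gmm taking an element-wise `1 / xp.sqrt(covariances)` rather than inverting a full matrix.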