mixture_model.py source code

python
Project: prml  Author: Yevgnen
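import numpy as np

# NOTE: this excerpt does not show its imports. The original names sp, spla and
# sprand appear to alias scipy, scipy.linalg and scipy.random; the NumPy
# equivalents are used here because SciPy has deprecated those re-exports.


def fit(self, X):
    """Hard-assignment K-Means (a method of the model class); X is (n_samples, n_features)."""
    n_samples, n_features = X.shape
    n_classes = self.n_classes
    max_iter = self.max_iter
    tol = self.tol

    # Initialise the centers with n_classes distinct random samples.
    rand_center_idx = np.random.permutation(n_samples)[:n_classes]
    center = X[rand_center_idx].T  # shape (n_features, n_classes)
    responsibility = np.zeros((n_samples, n_classes))

    for i in range(max_iter):
        # E step: assign every sample to its nearest center.
        # (n, d, 1) - (1, d, k) broadcasts to (n, d, k).
        dist = np.expand_dims(X, axis=2) - np.expand_dims(center, axis=0)
        dist = np.linalg.norm(dist, axis=1) ** 2  # squared distances, (n, k)
        min_idx = np.argmin(dist, axis=1)
        responsibility.fill(0)
        responsibility[np.arange(n_samples), min_idx] = 1

        # M step: move each center to the mean of its assigned samples.
        counts = np.sum(responsibility, axis=0)
        center_new = np.dot(X.T, responsibility) / np.maximum(counts, 1)
        # An empty cluster keeps its previous center instead of dividing by zero.
        center_new[:, counts == 0] = center[:, counts == 0]

        diff_norm = np.linalg.norm(center_new - center)
        center_norm = np.linalg.norm(center)
        print('K-Means: {0:5d} {1:4e}'.format(i, diff_norm / center_norm))
        # Stop once the relative change of the centers falls below tol.
        if diff_norm < tol * center_norm:
            break

        center = center_new

    self.center = center.T  # stored as (n_classes, n_features)
    self.responsibility = responsibility

    return self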
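
The loop above alternates a hard-assignment E step with a mean-update M step. As a minimal sketch of one such iteration outside the class, the following standalone function may help when checking the array shapes by hand. The name `kmeans_step`, the toy data, and the empty-cluster rule (keep the old center) are assumptions made for illustration and are not part of the prml project.

```python
import numpy as np


def kmeans_step(X, center):
    """One E step plus one M step of hard-assignment K-Means.

    X: (n_samples, n_features); center: (n_features, n_classes), the same
    layout that fit() uses internally. The function name is hypothetical.
    """
    # E step: squared distance from every sample to every center, shape (n, k).
    diff = X[:, :, None] - center[None, :, :]
    dist = np.sum(diff ** 2, axis=1)
    labels = np.argmin(dist, axis=1)

    # One-hot responsibilities, shape (n, k).
    resp = np.zeros((X.shape[0], center.shape[1]))
    resp[np.arange(X.shape[0]), labels] = 1

    # M step: mean of the samples assigned to each center.
    counts = resp.sum(axis=0)
    new_center = X.T @ resp / np.maximum(counts, 1)
    # Assumption: an empty cluster simply keeps its old center.
    new_center[:, counts == 0] = center[:, counts == 0]
    return labels, new_center


# Toy data: two well-separated groups in 2-D.
X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
center = X[[0, 2]].T  # start from one sample of each group
labels, new_center = kmeans_step(X, center)
print(labels)        # [0 0 1 1]
print(new_center.T)  # one row per center: (0, 0.5) and (10, 10.5)
```

Calling the function repeatedly until the centers stop moving gives the same iteration that `fit` performs, without the convergence logging.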