ProbabilisticModel.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:GRIPy 作者: giruenf 项目源码 文件源码
def fit(self, data):
        """Fit a Gaussian mixture model to ``data`` and store its parameters.

        When ``self.n`` is None, a mixture is fitted for every component
        count in ``range(self.n_min, self.n_max)``. Note that ``n_max`` is
        excluded by ``range``. The candidate selected by
        ``self._chosebestformetric(self.n_estimator, score)`` is kept.
        Otherwise, a single mixture with ``self.n`` components is fitted.

        Args:
            data: Samples passed unchanged to ``GMM.fit`` (and ``GMM.bic``).

        Side effects:
            Sets ``self.means``, ``self.covs`` and ``self.weights``.
            ``self.covs`` is always expanded to one covariance matrix per
            component, whatever ``self.cv_type`` was used for fitting.
        """
        def _per_component_covs(gmm, n_components):
            # Normalise gmm.covars_ to a stack of per-component matrices.
            if self.cv_type == 'full':
                return gmm.covars_
            if self.cv_type == 'tied':
                # covars_ is presumably one shared matrix; repeat it once per
                # component to get the same stacked layout as 'full'.
                return np.tile(gmm.covars_, (n_components, 1, 1))
            # Any other type: each row of covars_ is treated as the diagonal
            # of one component's covariance matrix.
            return np.array([np.diag(cv) for cv in gmm.covars_])

        if self.n is None:
            candidates = []  # (means, covs, weights) for each component count
            score = []
            for n in range(self.n_min, self.n_max):
                gmm = GMM(n_components=n, covariance_type=self.cv_type)
                gmm.fit(data)
                candidates.append(
                    (gmm.means_, _per_component_covs(gmm, n), gmm.weights_))
                # Only 'BIC' produces a score here; for any other estimator
                # an empty list is passed on.
                # NOTE(review): confirm _chosebestformetric handles that.
                if self.n_estimator == 'BIC':
                    score.append(gmm.bic(data))

            i_best = self._chosebestformetric(self.n_estimator, score)
            self.means, self.covs, self.weights = candidates[i_best]

        else:
            gmm = GMM(n_components=self.n, covariance_type=self.cv_type)
            gmm.fit(data)
            self.means = gmm.means_
            # Tile by self.n: the bare `n` used previously is only bound in
            # the search branch above and raised NameError here.
            self.covs = _per_component_covs(gmm, self.n)
            self.weights = gmm.weights_
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号