rvm.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:prml 作者: Yevgnen 项目源码 文件源码
def fit(self, X, T, max_iter=int(1e2), tol=1e-3, bound=1e10):
        """Fit a RVM model with the training data ``(X, T)``.

        Iterates the evidence-approximation re-estimation of a relevance vector
        machine (Bishop, PRML section 7.2.1):
        1. compute the Gaussian posterior over the weights;
        2. re-estimate the per-weight precisions ``alpha`` and the noise
           precision ``beta``.
        Weights whose precision reaches ``bound`` are pruned from the model.

        Parameters
        ----------
        X : array
            Training inputs; ``X.shape[0]`` is the number of samples.
        T : array
            Training targets, one per sample. Assumed 1-D (single output) --
            TODO confirm.
        max_iter : int
            Maximum number of re-estimation iterations (may be 0).
        tol : float
            Stop when the change in log evidence is below ``tol`` times its
            magnitude.
        bound : float
            Precision ceiling; weights with ``alpha >= bound`` are pruned.

        Returns
        -------
        self
            With these attributes set:
            - ``rv_indices``: kept columns of the bias-augmented design matrix;
            - ``mean`` / ``cov``: weight posterior over those columns;
            - ``beta``: noise precision.
            All four are computed from the same final hyperparameters.
        """
        # Use numpy directly: the numpy aliases in the top-level ``scipy``
        # namespace (scipy.ones, scipy.dot, ...) are deprecated and removed in
        # newer SciPy releases.
        import numpy as np

        # NOTE(review): _init_hyperparameters is expected to leave the initial
        # per-weight precisions in ``self.cov`` and the noise precision in
        # ``self.beta`` -- confirm against its definition.
        self._init_hyperparameters(X, T)

        T = np.asarray(T)
        n_samples = X.shape[0]
        # Design matrix with a leading column of ones for the bias weight (x0).
        phi = np.c_[np.ones(n_samples), self._compute_design_matrix(X)]

        # Work on a float copy so pruning/updates don't mutate self.cov in place.
        alpha = np.array(self.cov, dtype=float)
        beta = self.beta

        def posterior(cur_alpha, cur_beta):
            """Return the kept indices, their design columns and precisions,
            and the weight posterior covariance and mean for these
            hyperparameters."""
            rv_indices = np.nonzero(cur_alpha < bound)[0]
            rv_phi = phi[:, rv_indices]
            rv_alpha = cur_alpha[rv_indices]
            # Sigma = (A + beta Phi^T Phi)^-1,  m = beta Sigma Phi^T t
            post_cov = np.linalg.inv(
                np.diag(rv_alpha) + cur_beta * rv_phi.T.dot(rv_phi))
            post_mean = cur_beta * post_cov.dot(rv_phi.T.dot(T))
            return rv_indices, rv_phi, rv_alpha, post_cov, post_mean

        log_evidence = -1e10
        for _ in range(max_iter):
            # Clip diverged precisions (possibly inf) so they are pruned.
            alpha[alpha >= bound] = bound
            rv_indices, rv_phi, rv_alpha, post_cov, post_mean = posterior(alpha, beta)

            # Re-estimate the hyperparameters.
            # gamma_i = 1 - alpha_i Sigma_ii measures how well-determined w_i is.
            gamma = 1 - rv_alpha * post_cov.diagonal()
            rv_alpha = gamma / (post_mean * post_mean)
            residual = T - rv_phi.dot(post_mean)
            # The effective number of residual degrees of freedom is
            # N - sum(gamma); the former ``N + 1`` overstated it.
            beta = (n_samples - gamma.sum()) / np.linalg.norm(residual) ** 2
            # Store the new precisions before the convergence test so the
            # final posterior below uses the latest estimate.
            alpha[rv_indices] = rv_alpha

            # Evaluate the log evidence (up to an additive constant):
            #   -1/2 (log|C| + t^T C^-1 t),  C = beta^-1 I + Phi A^-1 Phi^T.
            # slogdet/solve avoid det() overflow and an explicit inverse.
            C = np.eye(n_samples) / beta + (rv_phi / rv_alpha).dot(rv_phi.T)
            _, log_det = np.linalg.slogdet(C)
            log_evidence_new = -0.5 * (log_det + T.dot(np.linalg.solve(C, T)))

            # Test the relative change in log evidence.
            converged = abs(log_evidence_new - log_evidence) < tol * abs(log_evidence)
            log_evidence = log_evidence_new
            if converged:
                break

        # Recompute the posterior from the final hyperparameters. The loop's
        # posterior predates the last re-estimation, so without this step
        # rv_indices, mean, cov and beta would not describe one consistent
        # model. It also covers max_iter == 0.
        alpha[alpha >= bound] = bound
        rv_indices, _, _, post_cov, post_mean = posterior(alpha, beta)

        self.rv_indices = rv_indices
        self.cov = post_cov
        self.mean = post_mean
        self.beta = beta

        return self
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号