topic_models.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:slda 作者: Savvysherpa 项目源码 文件源码
def fit(self, X, y):
        """
        Estimate the topic distributions per document (theta), term
        distributions per topic (phi), and regression coefficients (eta,
        stored in ``self.H``).

        Parameters
        ----------
        X : array-like, shape = (n_docs, n_terms)
            The document-term matrix. It must expose ``shape`` and ``sum()``
            (e.g. a NumPy array); it is passed unchanged to
            ``self._create_lookups``.

        y : array-like, shape = (n_edges, 3)
            Each entry of y is an ordered triple (d_1, d_2, y_(d_1, d_2)),
            where d_1 and d_2 are documents and y_(d_1, d_2) is an indicator of
            a directed edge from d_1 to d_2.

        Returns
        -------
        self
            The fitted model, so calls can be chained
            (e.g. ``model.fit(X, y).theta``).

        Raises
        ------
        ValueError
            If ``y`` is not two-dimensional with exactly three columns.

        Notes
        -----
        Sets ``doc_term_matrix``, ``n_docs``, ``n_terms``, ``n_tokens``,
        ``n_edges``, ``theta``, ``phi``, ``H`` and ``loglikelihoods`` on the
        instance. Sampler settings (``n_iter``, ``n_report_iter``,
        ``n_topics``, ``alpha``, ``beta``, ``mu``, ``nu2``, ``b``, ``seed``)
        are read from the instance and forwarded to ``gibbs_sampler_grtm``.
        """
        # Accept any array-like (e.g. a list of triples): everything below
        # relies on ndarray attributes such as ``shape`` and ``view``.
        y = np.asarray(y)
        # Validate before touching any model state, so a bad call does not
        # leave the instance half-updated. Without this check the failure
        # surfaces later as an opaque dtype-size error from ``view``.
        if y.ndim != 2 or y.shape[1] != 3:
            raise ValueError(
                'y must have shape (n_edges, 3); got shape {}'.format(y.shape))

        self.doc_term_matrix = X
        self.n_docs, self.n_terms = X.shape
        self.n_tokens = X.sum()
        self.n_edges = y.shape[0]
        doc_lookup, term_lookup = self._create_lookups(X)
        # edge info: prepend each edge's original position as an extra column
        # so it survives any reordering of the rows
        y = np.ascontiguousarray(np.column_stack((range(self.n_edges), y)))
        # we use a view here so that we can sort in-place using named columns
        y_rec = y.view(dtype=list(zip(('index', 'tail', 'head', 'data'),
                                      4 * [y.dtype])))
        # flatten() always returns a copy, so these per-edge arrays keep the
        # original edge order even if _create_edges reorders y_rec in place
        # (presumably it sorts by the requested column -- see comment above).
        edge_tail = np.ascontiguousarray(y_rec['tail'].flatten(),
                                         dtype=np.intc)
        edge_head = np.ascontiguousarray(y_rec['head'].flatten(),
                                         dtype=np.intc)
        edge_data = np.ascontiguousarray(y_rec['data'].flatten(),
                                         dtype=np.float64)
        out_docs, out_edges = self._create_edges(y_rec, order='tail')
        in_docs, in_edges = self._create_edges(y_rec, order='head')
        # iterate
        self.theta, self.phi, self.H, self.loglikelihoods = gibbs_sampler_grtm(
            self.n_iter, self.n_report_iter, self.n_topics, self.n_docs,
            self.n_terms, self.n_tokens, self.n_edges, self.alpha, self.beta,
            self.mu, self.nu2, self.b, doc_lookup, term_lookup, out_docs,
            out_edges, in_docs, in_edges, edge_tail, edge_head, edge_data,
            self.seed)
        return self
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号