training.py source code

python

Project: treecat  Author: posterior
def sample_tree(self):
        """Samples a random tree.

        Returns:
            A pair (edges, edge_logits), where:
                edges: A list of (vertex, vertex) pairs.
                edge_logits: A [K]-shaped numpy array of edge logits.
        """
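        logger.info('TreeCatTrainer.sample_tree given %d rows',
                    len(self._added_rows))
        # Record how many rows this tree sample was based on.
        SERIES.sample_tree_num_rows.append(len(self._added_rows))
        # complete_grid describes every candidate edge of the complete graph;
        # there must be exactly one float32 logit per candidate edge.
        complete_grid = self._tree.complete_grid
        edge_logits = self.compute_edge_logits()
        assert edge_logits.shape[0] == complete_grid.shape[1]
        assert edge_logits.dtype == np.float32
        # Resample the tree, passing in the current edges.
        edges = self.get_edges()
        edges = sample_tree(complete_grid, edge_logits, edges)
        return edges, edge_logits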
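The method hands the actual sampling to a module-level `sample_tree(complete_grid, edge_logits, edges)`, whose body is not shown here. One standard way to draw spanning trees with probability proportional to the exponential of their summed edge logits is a Gibbs sweep. For each tree edge in turn, the sweep removes it, which splits the tree into two components, then picks a reconnecting edge across that cut with probability proportional to `exp(logit)`. The sketch below illustrates that technique under stated assumptions. It is not treecat's implementation. The names `make_complete_edges`, `reachable` and `gibbs_sample_tree` are hypothetical, and a plain list of vertex pairs stands in for `complete_grid`, whose real layout is not given in this snippet.

```python
import itertools

import numpy as np


def make_complete_edges(num_vertices):
    # Hypothetical stand-in for complete_grid: edge k of the complete graph
    # is complete_edges[k], so a [K]-shaped logits array has K = V*(V-1)/2.
    return list(itertools.combinations(range(num_vertices), 2))


def reachable(start, edges, num_vertices):
    # Vertices reachable from `start` using only `edges` (iterative DFS).
    adjacency = [[] for _ in range(num_vertices)]
    for a, b in edges:
        adjacency[a].append(b)
        adjacency[b].append(a)
    seen = {start}
    stack = [start]
    while stack:
        v = stack.pop()
        for w in adjacency[v]:
            if w not in seen:
                seen.add(w)
                stack.append(w)
    return seen


def gibbs_sample_tree(complete_edges, edge_logits, tree, rng):
    """One Gibbs sweep over the edges of a spanning tree (hypothetical helper).

    Targets P(tree) proportional to exp(sum of the tree's edge logits).
    tree: list of (v1, v2) pairs, each of which appears in complete_edges.
    """
    num_vertices = max(max(e) for e in complete_edges) + 1
    tree = list(tree)
    for i in range(len(tree)):
        rest = tree[:i] + tree[i + 1:]
        # Removing tree[i] leaves two components; find one side of the cut.
        side = reachable(tree[i][0], rest, num_vertices)
        # Every complete-graph edge crossing the cut reconnects the tree.
        cands = [k for k, (a, b) in enumerate(complete_edges)
                 if (a in side) != (b in side)]
        logits = edge_logits[cands].astype(np.float64)
        probs = np.exp(logits - logits.max())  # numerically stable softmax
        probs /= probs.sum()
        tree[i] = complete_edges[rng.choice(cands, p=probs)]
    return tree


# Usage: four vertices, random float32 logits, starting from a path tree.
rng = np.random.default_rng(0)
complete_edges = make_complete_edges(4)  # K = 6 candidate edges
edge_logits = rng.normal(size=len(complete_edges)).astype(np.float32)
tree = gibbs_sample_tree(complete_edges, edge_logits,
                         [(0, 1), (1, 2), (2, 3)], rng)
```

Each step is an exact Gibbs update. Once the other V - 2 edges are fixed, the only spanning trees left are those that add a single edge crossing the cut, and their probabilities differ only by `exp` of that edge's logit. Repeated sweeps therefore form a Markov chain whose stationary distribution is the target.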