def sample_tree(self):
    """Samples a random tree over the complete grid.

    The row count is logged and recorded in the SERIES metrics before
    sampling. Edge logits come from `self.compute_edge_logits()`. The current
    edges from `self.get_edges()` are passed to the module-level
    `sample_tree` helper, which returns the new edge list.

    Returns:
        A pair (edges, edge_logits), where:
        edges: A list of (vertex, vertex) pairs.
        edge_logits: A [K]-shaped numpy array of edge logits.
    """
    num_rows = len(self._added_rows)
    logger.info('TreeCatTrainer.sample_tree given %d rows', num_rows)
    SERIES.sample_tree_num_rows.append(num_rows)

    grid = self._tree.complete_grid
    logits = self.compute_edge_logits()
    # One logit per grid column, stored as float32; these are internal
    # invariants the sampler relies on, not caller-input validation.
    assert logits.shape[0] == grid.shape[1]
    assert logits.dtype == np.float32

    # NOTE(review): the current edges are presumably the sampler's starting
    # state — confirm against the module-level sample_tree helper.
    current_edges = self.get_edges()
    sampled_edges = sample_tree(grid, logits, current_edges)
    return sampled_edges, logits
评论列表
文章目录