def test_estimate_tree(num_edges):
    """Check that ``estimate_tree`` returns a valid, optimal spanning tree.

    A complete graph on ``num_edges + 1`` vertices is built and given one
    random logit per column of ``grid`` (presumably one per edge of the
    complete graph), drawn uniformly from [-0.5, 0.5) after seeding.

    Checks performed:
      * the estimated tree has exactly ``num_edges`` edges;
      * every vertex appears in at least one edge;
      * for small graphs (fewer vertices than ``len(TREE_GENERATORS)``), the
        tree is among those returned by ``get_spanning_trees`` and matches
        the one with the largest total logit found by brute force.
    """
    set_random_seed(0)
    num_vertices = 1 + num_edges
    grid = make_complete_graph(num_vertices)
    num_grid_edges = grid.shape[1]
    edge_logits = np.random.random([num_grid_edges]) - 0.5
    edges = estimate_tree(grid, edge_logits)

    # Check size: a spanning tree on V vertices has exactly V - 1 edges,
    # and every vertex must be touched by some edge.
    assert len(edges) == num_edges
    for vertex in range(num_vertices):
        assert any(vertex in edge for edge in edges)

    # Check optimality by exhaustive search; only done when the graph is
    # small enough for TREE_GENERATORS to enumerate its spanning trees.
    edges = tuple(edges)
    if num_vertices >= len(TREE_GENERATORS):
        return
    candidates = get_spanning_trees(num_vertices)
    assert edges in candidates
    candidates = list(candidates)

    def total_logit(tree):
        # Sum the logits of the tree's edges, mapping each (u, v) pair to
        # its index in edge_logits via find_complete_edge.
        return sum(edge_logits[find_complete_edge(u, v)] for (u, v) in tree)

    scores = [total_logit(tree) for tree in candidates]
    best = candidates[np.argmax(scores)]
    assert edges == best
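# 评论列表 ("comment list") — scraped web-page residue, not part of the test.
# 文章目录 ("table of contents") — scraped web-page residue, not part of the test.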