structure_test.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:treecat 作者: posterior 项目源码 文件源码
def test_estimate_tree(num_edges):
    """Check that ``estimate_tree`` returns a spanning tree of maximal score.

    A graph with ``num_edges + 1`` vertices is built via
    ``make_complete_graph`` and given random per-edge logits in
    ``[-0.5, 0.5)``. The tree returned by ``estimate_tree`` must:

    * contain exactly ``num_edges`` edges;
    * touch every vertex;
    * for sizes covered by ``TREE_GENERATORS``, appear among the trees
      returned by ``get_spanning_trees`` and equal the enumerated tree
      whose summed edge logits are largest.

    Args:
        num_edges: Number of edges the estimated tree should have.
    """
    set_random_seed(0)
    num_vertices = num_edges + 1
    complete_grid = make_complete_graph(num_vertices)
    # NOTE(review): presumably shape[1] is the number of edges in the
    # complete graph (one logit per edge) -- confirm against
    # make_complete_graph.
    num_grid_edges = complete_grid.shape[1]
    edge_logits = np.random.random([num_grid_edges]) - 0.5
    tree_edges = estimate_tree(complete_grid, edge_logits)

    # Size check: the tree has num_edges edges and touches every vertex.
    assert len(tree_edges) == num_edges
    for vertex in range(num_vertices):
        assert any(vertex in edge for edge in tree_edges)

    # Optimality check by brute force. Only run it for vertex counts small
    # enough to have an entry in TREE_GENERATORS.
    tree_edges = tuple(tree_edges)
    if num_vertices >= len(TREE_GENERATORS):
        return
    candidate_trees = get_spanning_trees(num_vertices)
    assert tree_edges in candidate_trees
    # Materialize once so the scores below and the argmax index refer to
    # the same ordering.
    candidate_trees = list(candidate_trees)
    tree_scores = [
        sum(edge_logits[find_complete_edge(u, v)] for (u, v) in tree)
        for tree in candidate_trees
    ]
    best_tree = candidate_trees[np.argmax(tree_scores)]
    assert tree_edges == best_tree
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号