plotting.py source code

python

Project: treecat  Author: posterior
def nx_plot_tree(server, node_size=200, **options):
    """Visualize the tree using the networkx package.

    This plots to the current matplotlib figure.

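    Args:
        server: A DataServer instance.
        node_size: Size of the largest node; other nodes are scaled down
            in proportion to their latent perplexity.
        options: Options passed to networkx.draw().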
    """
    import networkx as nx
    edges = server.estimate_tree()
    perplexity = server.latent_perplexity()
    feature_names = server.feature_names

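    # A tree on V vertices has V - 1 edges, so recover V from the edge count.
    V = 1 + len(edges)
    G = nx.Graph()
    G.add_nodes_from(range(V))
    G.add_edges_from(edges)
    # Relabel integer vertex ids with the corresponding feature names.
    H = nx.relabel_nodes(G, dict(enumerate(feature_names)))
    # Scale node sizes so the highest-perplexity node gets node_size.
    node_size = node_size * perplexity / perplexity.max()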

    options.setdefault('alpha', 0.2)
    options.setdefault('font_size', 8)
    nx.draw(H, with_labels=True, node_size=node_size, **options)
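
The graph-building steps in nx_plot_tree can be tried without a real DataServer. The sketch below is only an illustration: StubServer, its feature names, its edges, and its perplexity values are all made up, and build_labeled_tree is a hypothetical helper that repeats the steps above but returns the graph instead of drawing it.

```python
import networkx as nx
import numpy as np


class StubServer:
    """Hypothetical stand-in exposing only what nx_plot_tree reads."""

    # Made-up feature names, one per vertex.
    feature_names = ['age', 'height', 'weight', 'income']

    def estimate_tree(self):
        # Edges as pairs of feature indices; 4 vertices need 3 edges.
        return [(0, 1), (1, 2), (1, 3)]

    def latent_perplexity(self):
        # One made-up value per feature.
        return np.array([2.0, 4.0, 1.0, 3.0])


def build_labeled_tree(server, node_size=200):
    """Repeat the graph-building steps of nx_plot_tree without drawing."""
    edges = server.estimate_tree()
    perplexity = server.latent_perplexity()
    G = nx.Graph()
    G.add_nodes_from(range(1 + len(edges)))
    G.add_edges_from(edges)
    # Swap integer vertex ids for feature names.
    H = nx.relabel_nodes(G, dict(enumerate(server.feature_names)))
    # The largest node gets node_size; the rest shrink proportionally.
    sizes = node_size * perplexity / perplexity.max()
    return H, sizes


H, sizes = build_labeled_tree(StubServer())
```

Passing the result to nx.draw(H, with_labels=True, node_size=sizes) should render the same kind of picture as nx_plot_tree, provided matplotlib is installed.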