def nx_plot_tree(server, node_size=200, **options):
    """Visualize the tree using the networkx package.

    This plots to the current matplotlib figure.

    Args:
        server: A DataServer instance. This function calls
            ``server.estimate_tree()`` for the edge list,
            ``server.latent_perplexity()`` for per-node sizes (must support
            ``.max()`` and elementwise arithmetic), and reads
            ``server.feature_names`` for node labels.
        node_size: Drawn size of the node with the largest latent
            perplexity; every other node is scaled proportionally.
        options: Options passed to networkx.draw(). ``alpha`` defaults to
            0.2, ``font_size`` to 8 and ``with_labels`` to True; any of
            these may be overridden by the caller.
    """
    import networkx as nx

    edges = server.estimate_tree()
    perplexity = server.latent_perplexity()
    feature_names = server.feature_names

    # A tree with E edges spans E + 1 nodes; nodes are numbered 0..E.
    num_nodes = 1 + len(edges)
    graph = nx.Graph()
    graph.add_nodes_from(range(num_nodes))
    graph.add_edges_from(edges)

    # Swap integer ids for feature names so the plot is readable. Note:
    # duplicate feature names would merge nodes under relabel_nodes.
    labeled = nx.relabel_nodes(graph, dict(enumerate(feature_names)))

    # Normalize so the most-perplex node is drawn at exactly `node_size`.
    sizes = node_size * perplexity / perplexity.max()

    options.setdefault('alpha', 0.2)
    options.setdefault('font_size', 8)
    # Supplied via setdefault (not as an explicit keyword) so a caller who
    # passes with_labels=... does not hit a duplicate-keyword TypeError.
    options.setdefault('with_labels', True)
    nx.draw(labeled, node_size=sizes, **options)
评论列表
文章目录