def add_nodes(node_table, name=None, name_scope=None, style=True):
    """
    Add TensorFlow graph nodes to a graphviz.dot.Digraph, nesting clusters.

    Each key of ``node_table`` is a '/'-separated node name and its value is a
    mapping of the same shape holding that node's children. A key with a
    non-empty value is rendered as its own cluster subgraph (built
    recursively, with the key itself added as a node inside the cluster);
    a key with an empty value is added as a plain node. Every node is
    labelled with the last '/'-separated component of its key.

    Side effect: increments the module-level CLUSTER_INDEX once per cluster
    created, so cluster names stay unique across calls.

    @param node_table  mapping of node name -> mapping of child nodes
    @param name        name for the created graph; when falsy, a random
                       6-character uppercase hex name is generated
    @param name_scope  forwarded to tf_digraph for the created graph
    @param style       forwarded to tf_digraph (and to every nested cluster)
    @return graphviz.dot.Digraph object
    """
    global CLUSTER_INDEX
    if not name:
        # `UUID.hex` exists on Python 2 and 3; `get_hex()` was Python 2 only.
        name = uuid.uuid4().hex.upper()[:6]
    digraph = tf_digraph(name=name, name_scope=name_scope, style=style)

    subgraphs = []
    for key, children in node_table.items():
        label = key.split('/')[-1]
        if children:
            # Reserve this cluster's index BEFORE recursing. Incrementing only
            # after the recursive call let the first nested cluster reuse the
            # parent's name. The "cluster" prefix makes graphviz draw the
            # subgraph as a boxed cluster.
            cluster_name = 'cluster_%i' % CLUSTER_INDEX
            CLUSTER_INDEX += 1
            sg = add_nodes(children, name=cluster_name, name_scope=label, style=style)
            sg.node(key, label)
            subgraphs.append(sg)
        else:
            digraph.node(key, label)

    # Attach clusters after all plain nodes have been added, as before.
    for sg in subgraphs:
        digraph.subgraph(sg)
    return digraph
评论列表
文章目录