def save_visualization(name, format='jpg'):
    """Render the recorded variable/function trace as a graphviz diagram.

    The graph is built from tables that live outside this function
    (not visible here; presumably filled in by tracing code elsewhere):

    * ``vars``       -- dict-like ``id -> variable``. Each entry becomes an
                        ellipse node labelled with the variable's size:
                        red for ``nn.Parameter``, light blue for any other
                        ``Variable``.
    * ``funcs``      -- indexable ``id -> creator object``, used for the
                        label (class name) of each creator node.
    * ``func_trace`` -- ``creator id -> iterable of input ids``. Each key
                        becomes an orange rectangle node, with an edge
                        from every listed input to it.
    * ``var_trace``  -- ``output id -> iterable of creator ids``. Each pair
                        becomes an edge from the creator to the output.

    Args:
        name: Base filename passed to ``graphviz.Digraph.render``.
        format: Output format passed to ``graphviz.Digraph`` (e.g. 'jpg').

    Returns:
        The path of the rendered file, as returned by ``Digraph.render``.

    Raises:
        TypeError: if ``vars`` contains something that is neither an
            ``nn.Parameter`` nor a ``Variable``.
    """
    g = graphviz.Digraph(format=format)

    def sizestr(var):
        # Node label: the variable's dimensions as a list, e.g. "[2, 3]".
        return str([int(i) for i in var.size()])

    # add variable nodes
    for vid, var in vars.items():
        # nn.Parameter is tested first so parameters get their own colour;
        # they would presumably also satisfy the Variable check below.
        if isinstance(var, nn.Parameter):
            g.node(str(vid), label=sizestr(var), shape='ellipse',
                   style='filled', fillcolor='red')
        elif isinstance(var, Variable):
            g.node(str(vid), label=sizestr(var), shape='ellipse',
                   style='filled', fillcolor='lightblue')
        else:
            # Explicit raise rather than `assert`, which `python -O` strips
            # (that would silently leave this entry out of the graph).
            raise TypeError(
                'unexpected variable type: {}'.format(var.__class__))

    # add creator nodes
    for cid in func_trace:
        creator = funcs[cid]
        g.node(str(cid), label=str(creator.__class__.__name__),
               shape='rectangle', style='filled', fillcolor='orange')

    # add edges between creator and inputs
    for cid in func_trace:
        for iid in func_trace[cid]:
            g.edge(str(iid), str(cid))

    # add edges between outputs and creators
    for oid in var_trace:
        for cid in var_trace[oid]:
            g.edge(str(cid), str(oid))

    return g.render(name)
# (Scraped web-page navigation text, not code: "评论列表" = "comment list",
#  "文章目录" = "table of contents".)