def addnodes(self, var):
    """Add ``var`` and everything upstream of it to the graph in ``self.dot``.

    A node is created for ``var`` (keyed by ``str(id(var))``), ``var`` is
    recorded in ``self.seen``, and then every entry of
    ``var.previous_functions`` (when that attribute exists) is linked to
    ``var`` with an edge and visited recursively.  Objects already in
    ``self.seen`` are skipped, so shared sub-graphs are emitted once.

    Node appearance:

    * ``Variable`` instances use ``self.style_params`` and are labelled with
      their size, e.g. ``(64, 3, 7, 7)``.
    * Anything else uses ``self.style_layers``, is labelled with its type
      name (``Threshold`` is displayed as ``ReLU``) and is filled with a
      colour chosen by type name; unknown types get ``aquamarine``.

    ``self.style_params`` and ``self.style_layers`` are treated as read-only
    templates: per-node label/colour are applied to a copy, so no state
    leaks from one node to the next or to other users of those dicts.

    NOTE(review): ``Variable`` and ``pydot`` must be importable at module
    level; ``Variable`` is presumably ``torch.autograd.Variable`` -- confirm
    against the file's imports.

    Args:
        var: Graph object to add (a ``Variable`` or a function/layer object).

    Returns:
        None.  The graph is built up in ``self.dot`` as a side effect.
    """
    if var in self.seen:
        return

    # Fill colour per layer type name; anything not listed falls back to
    # the default below.
    fill_colors = {
        "ConvNd": "cyan",
        "BatchNorm": "darkseagreen",
        "Threshold": "crimson",
        "Add": "darkorchid",
        "AvgPool2d": "gold",
        "Linear": "chartreuse",
        "View": "brown",
    }
    default_fill = "aquamarine"
    # Type names that are shown under a different label.
    display_names = {"Threshold": "ReLU"}

    if isinstance(var, Variable):
        label = '(' + ', '.join('%d' % dim for dim in var.size()) + ')'
        attrs = dict(self.style_params, label=label)
    else:
        type_name = type(var).__name__
        attrs = dict(
            self.style_layers,
            label=display_names.get(type_name, type_name),
            fillcolor=fill_colors.get(type_name, default_fill),
        )

    node_id = str(id(var))
    self.dot.add_node(pydot.Node(node_id, **attrs))
    self.seen.add(var)

    # Objects without ``previous_functions`` are leaves of the traversal.
    for entry in getattr(var, 'previous_functions', ()):
        upstream = entry[0]
        # Edges are drawn to every upstream object, Variables included, so
        # the edge is added before (and regardless of) the recursive visit.
        self.dot.add_edge(pydot.Edge(str(id(upstream)), node_id))
        self.addnodes(upstream)
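# [web-page residue, not part of the code] Comment list
# [web-page residue, not part of the code] Table of contents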