graphgen.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:ExperimentPackage_PyTorch 作者: ICEORY 项目源码 文件源码
def addnodes(self, var):
        """Add ``var`` and everything reachable from it to ``self.dot``.

        Each object becomes a ``pydot.Node`` whose name is ``str(id(obj))``:

        * ``Variable`` objects are labelled with their size tuple, e.g.
          ``"(1, 3, 224, 224)"``, and drawn with ``self.style_params``.
        * Any other object is labelled with its class name and drawn with
          ``self.style_layers``. The fill colour is chosen by class name, and
          ``Threshold`` is displayed as ``"ReLU"``.

        Parents are taken from ``obj.previous_functions``, when present. Each
        entry is indexed with ``[0]`` to get the parent object; presumably this
        is the legacy PyTorch autograd ``(function, index)`` pair (TODO
        confirm). An edge ``parent -> obj`` is added for every entry, even if
        the parent was already drawn. Objects already in ``self.seen`` are not
        drawn or expanded again.

        The traversal uses an explicit stack instead of recursion, so very
        deep graphs cannot hit Python's recursion limit. Nodes and edges are
        still added in the same depth-first order as a recursive walk.

        Args:
            var: root object to start from (a ``Variable`` or an autograd
                function object).
        """
        # Fill colour per autograd class name; unknown classes get "aquamarine".
        layer_colors = {
            "ConvNd": "cyan",
            "BatchNorm": "darkseagreen",
            "Threshold": "crimson",
            "Add": "darkorchid",
            "AvgPool2d": "gold",
            "Linear": "chartreuse",
            "View": "brown",
        }
        # Display-name overrides for class names that read poorly in the graph.
        layer_labels = {"Threshold": "ReLU"}

        def enter(node):
            """Draw ``node`` if unseen and mark it seen; return True if it was new."""
            if node in self.seen:
                return False
            if isinstance(node, Variable):
                # Copy so the shared style template is never mutated.
                style = dict(self.style_params)
                style["label"] = "(" + ", ".join("%d" % v for v in node.size()) + ")"
            else:
                name = type(node).__name__
                style = dict(self.style_layers)
                style["label"] = layer_labels.get(name, name)
                style["fillcolor"] = layer_colors.get(name, "aquamarine")
            self.dot.add_node(pydot.Node(str(id(node)), **style))
            self.seen.add(node)
            return True

        def parents(node):
            """Iterator over ``node.previous_functions`` (empty if absent)."""
            return iter(getattr(node, "previous_functions", ()))

        if not enter(var):
            return

        # Each stack frame is (child, remaining parent entries of that child).
        # This mirrors the recursion: an edge is emitted, then the parent is
        # entered and expanded before the child's next entry is handled.
        stack = [(var, parents(var))]
        while stack:
            child, pending = stack[-1]
            try:
                entry = next(pending)
            except StopIteration:
                stack.pop()
                continue
            parent = entry[0]
            self.dot.add_edge(pydot.Edge(str(id(parent)), str(id(child))))
            if enter(parent):
                stack.append((parent, parents(parent)))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号