def update(self):
    """Rebuild the per-type index table and recompute every node's depth.

    ``self.type2ind`` is reset and refilled: each node type maps to the list
    of integer suffixes parsed from the names of the nodes of that type
    (a node named ``Conv2D3`` contributes ``3``).

    Nodes are then visited in topological order. A node without
    predecessors gets ``depth`` 0. Any other node gets the largest depth
    among its predecessors, plus one when its own type is ``Conv2D``,
    ``Group`` or ``Conv2D_Pooling``. ``self.max_depth`` is set to the
    largest depth assigned to any node (0 for a graph with no nodes or with
    root nodes only).

    Raises:
        IndexError: if a node name does not end in at least one digit
            preceded by at least one word character.
    """
    # Function-scope import: the original imported ``re`` locally, so the
    # module may not import it at file level.
    import re

    suffix_pattern = re.compile(r'^\w+?(\d+)$')
    depth_adding_types = ('Conv2D', 'Group', 'Conv2D_Pooling')

    self.type2ind = {}
    for node in self.nodes():
        # findall(...)[0] deliberately raises IndexError on a malformed name.
        ind = int(suffix_pattern.findall(node.name)[0])
        self.type2ind.setdefault(node.type, []).append(ind)

    deepest = 0
    for node in nx.topological_sort(self):
        # Materialize once: on networkx >= 2.0 predecessors() returns an
        # iterator, which supports neither len() nor a second pass.
        parents = list(self.predecessors(node))
        if parents:
            plus = 1 if node.type in depth_adding_types else 0
            # Topological order guarantees every parent already has a depth.
            node.depth = max(parent.depth for parent in parents) + plus
        else:
            node.depth = 0
        # Track the true maximum; the last node visited is not necessarily
        # the deepest one when the graph has several sinks.
        deepest = max(deepest, node.depth)
    self.max_depth = deepest
# NOTE(review): web-page residue, not code: "Comment list" / "Table of contents".