def add_skip(self, model, config):
    """Add a skip connection between two consecutive conv-like layers.

    Layers whose ``type`` is ``Conv2D``, ``Group`` or ``Conv2D_Pooling`` are
    collected in topological order into ``names``.  A pair
    ``(names[i], names[i + 1])`` is a candidate when:

    * ``i + 1`` is not the last such layer, and
    * the nodes returned by
      ``model.graph.get_nodes(names[i + 1], next_layer=True, last_layer=False)``
      contain no node of type ``Add``.  Presumably these are the layer's
      successors, so a layer that already feeds a residual sum is not given a
      second one (NOTE(review): confirm the ``get_nodes`` semantics).

    One candidate pair is chosen uniformly at random with ``np.random`` and
    handed to ``self.skip``.

    Args:
        model: object with a ``graph`` attribute.  From usage, that graph must
            work with ``networkx`` DAG helpers, its nodes must expose ``name``
            and ``type``, and it must provide ``get_nodes`` (assumed; verify
            against the graph class).
        config: forwarded unchanged to ``self.skip``.

    Returns:
        ``(self.skip(model, from_name, to_name, config), True)`` when a
        candidate pair exists, otherwise ``(model, False)``.

    Raises:
        AssertionError: if ``model.graph`` is not a directed acyclic graph.
    """
    assert nx.is_directed_acyclic_graph(model.graph)
    conv_types = ('Conv2D', 'Group', 'Conv2D_Pooling')
    names = [node.name for node in nx.topological_sort(model.graph)
             if node.type in conv_types]

    # Candidate "from" indices are 0 .. len(names) - 3, so the "to" layer
    # (from + 1) is never the last conv-like layer.  This is the same range
    # the previous np.random.randint(0, len(names) - 2) draw covered.
    # Empty (or negative-length) ranges cover the "<= 2 layers" case.
    candidates = []
    for from_idx in range(len(names) - 2):
        next_nodes = model.graph.get_nodes(names[from_idx + 1],
                                           next_layer=True, last_layer=False)
        if not any(node.type == 'Add' for node in next_nodes):
            candidates.append(from_idx)

    if not candidates:
        logger.info('can\'t find a suitable layer to apply add_skip operation,return origin model')
        return model, False

    # Uniform choice over the valid pairs: the same distribution the old
    # rejection-sampling loop produced on success.  Unlike that loop, it
    # cannot give up after a fixed number of unlucky draws while a valid
    # pair still exists.
    from_idx = candidates[np.random.randint(len(candidates))]
    from_name = names[from_idx]
    to_name = names[from_idx + 1]
    logger.info('choose %s and %s to add_skip', from_name, to_name)
    return self.skip(model, from_name, to_name, config), True
# add group operation
评论列表
文章目录