def get_layers(self, name, next_layer=False, last_layer=False, type=None):
    """Return the model layers matching the graph node(s) found for *name*.

    The graph is queried with
    ``self.graph.get_nodes(name, next_layer, last_layer, type=type)``.
    Each returned node's ``name`` is then resolved, by exact name, to a layer
    in ``self.model.layers``.

    Args:
        name: Passed unchanged to ``self.graph.get_nodes``.
        next_layer: Passed unchanged to ``self.graph.get_nodes``.
        last_layer: Passed unchanged to ``self.graph.get_nodes``.
        type: Optional filter on candidate layers. ``None`` makes every layer
            in ``self.model.layers`` a candidate. Otherwise it is a single
            string or an iterable of strings. A layer is a candidate when any
            of them occurs, case-insensitively, as a substring of the layer's
            name. Also passed unchanged to ``self.graph.get_nodes``.

    Returns:
        A lazy ``map`` iterator yielding one layer per node returned by
        ``self.graph.get_nodes``, in the same order.

    Raises:
        KeyError: While iterating the result, if a node's name is not the name
            of a candidate layer.
    """
    if type is None:
        name2layer = {layer.name: layer for layer in self.model.layers}
    else:
        # A bare string is one pattern. Iterating it directly would treat
        # every character as a separate pattern.
        patterns = [type] if isinstance(type, str) else type
        # Materialise once: an iterator/generator `type` would otherwise be
        # exhausted after the first layer. This also hoists `.lower()` out
        # of the per-layer loop.
        lowered = [t.lower() for t in patterns]
        name2layer = {}
        for layer in self.model.layers:
            layer_name = layer.name.lower()
            if any(p in layer_name for p in lowered):
                name2layer[layer.name] = layer

    def _get_layer(node_name):
        # Same exception type as a plain dict lookup, but with context about
        # which name failed and which filter was active.
        if node_name not in name2layer:
            raise KeyError(
                "layer %r not found among candidate layers (type=%r)"
                % (node_name, type)
            )
        return name2layer[node_name]

    nodes = self.graph.get_nodes(name, next_layer, last_layer, type=type)
    if not isinstance(nodes, list):
        nodes = [nodes]
    return map(_get_layer, [node.name for node in nodes])
评论列表
文章目录