Model.py source code

python

Project: NetworkCompress  Author: luzai
def get_layers(self, name, next_layer=False, last_layer=False, type=None):
    # Map layer names to layer objects. If `type` is given (an iterable of
    # substrings, e.g. ["conv"]), keep only layers whose name contains one
    # of those substrings, compared case-insensitively.
    if type is None:
        name2layer = {layer.name: layer for layer in self.model.layers}
    else:
        name2layer = {}
        for layer in self.model.layers:
            for t in type:
                if t.lower() in layer.name.lower():
                    name2layer[layer.name] = layer
                    break

    # Ask the graph for the matching node(s); it may return a single node.
    nodes = self.graph.get_nodes(name, next_layer, last_layer, type=type)
    if not isinstance(nodes, list):
        nodes = [nodes]

    # Resolve each node to the layer of the same name. Raises KeyError if a
    # node has no such layer (e.g. the layer was filtered out by `type`).
    # Return a list: on Python 3, map() would return a lazy iterator.
    return [name2layer[node.name] for node in nodes]
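To see how the name lookup and the `type` filter interact, here is a minimal, self-contained sketch. `Layer`, `Node`, `StubModel`, `StubGraph` and `Net` are hypothetical stand-ins, not NetworkCompress classes; in particular the stub graph ignores its arguments and returns a fixed node list instead of walking a real network graph.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Layer:
    # Hypothetical stand-in for a model layer; only `.name` is used.
    name: str


@dataclass
class Node:
    # Hypothetical stand-in for a graph node; only `.name` is used.
    name: str


@dataclass
class StubModel:
    layers: List[Layer]


@dataclass
class StubGraph:
    # Hypothetical graph: always returns the nodes it was built with.
    nodes: List[Node]

    def get_nodes(self, name, next_layer=False, last_layer=False, type=None):
        return self.nodes


class Net:
    def __init__(self, model, graph):
        self.model = model
        self.graph = graph

    def get_layers(self, name, next_layer=False, last_layer=False, type=None):
        # Same lookup/filter logic as the method above.
        if type is None:
            name2layer = {layer.name: layer for layer in self.model.layers}
        else:
            name2layer = {}
            for layer in self.model.layers:
                for t in type:
                    if t.lower() in layer.name.lower():
                        name2layer[layer.name] = layer
                        break
        nodes = self.graph.get_nodes(name, next_layer, last_layer, type=type)
        if not isinstance(nodes, list):
            nodes = [nodes]
        return [name2layer[node.name] for node in nodes]


layers = [Layer("Conv2D_1"), Layer("BatchNorm_1"), Layer("Dense_1")]
net = Net(StubModel(layers), StubGraph([Node("Conv2D_1"), Node("Dense_1")]))

# No filter: every returned node resolves to its layer.
print([l.name for l in net.get_layers("Conv2D_1")])  # ['Conv2D_1', 'Dense_1']

# Filter keeps only "conv" layers, so the Dense_1 node cannot be resolved.
try:
    net.get_layers("Conv2D_1", type=["conv"])
except KeyError as e:
    print("missing:", e)  # missing: 'Dense_1'
```

Because `type` is iterated element by element, passing a bare string such as `"conv"` would match on the individual characters `c`, `o`, `n`, `v` rather than on the substring; pass a list such as `["conv"]` instead.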