Net2Net.py 文件源码

python
阅读 51 收藏 0 点赞 0 评论 0

项目:NetworkCompress 作者: luzai 项目源码 文件源码
def add_skip(self, model, config):
    """Add a skip connection between two adjacent conv-like layers.

    Collects, in topological order, the names of nodes whose ``type`` is
    ``'Conv2D'``, ``'Group'`` or ``'Conv2D_Pooling'``. Among adjacent pairs
    ``(names[i], names[i + 1])`` with ``i`` in ``0 .. len(names) - 3``, it keeps
    those where ``model.graph.get_nodes(names[i + 1], next_layer=True,
    last_layer=False)`` reports no node of type ``'Add'``. One such pair is
    chosen uniformly at random and handed to ``self.skip``. The ``'Add'``
    check presumably means "the second layer does not already feed a residual
    Add" — TODO confirm against ``get_nodes``.

    Args:
        model: Object whose ``graph`` attribute is a networkx graph that also
            provides ``get_nodes``; its nodes expose ``name`` and ``type``.
        config: Forwarded unchanged to ``self.skip``.

    Returns:
        ``(self.skip(model, from_name, to_name, config), True)`` when a pair
        was found, otherwise ``(model, False)`` with ``model`` returned as-is.

    Raises:
        AssertionError: if ``model.graph`` is not a directed acyclic graph.
    """
    assert nx.is_directed_acyclic_graph(model.graph)
    conv_like_types = ('Conv2D', 'Group', 'Conv2D_Pooling')
    names = [node.name for node in nx.topological_sort(model.graph)
             if node.type in conv_like_types]

    # Enumerate every admissible pair up front so that a valid pair is always
    # found when one exists; random retries with a fixed budget could give up
    # early. The target index stops at len(names) - 2, so the last conv-like
    # layer is never used as the skip target (same range as before).
    # With fewer than three conv-like layers this range is empty.
    candidates = []
    for from_idx in range(len(names) - 2):
        next_nodes = model.graph.get_nodes(names[from_idx + 1], next_layer=True, last_layer=False)
        if 'Add' not in [node.type for node in next_nodes]:
            candidates.append(from_idx)

    if not candidates:
        logger.info('can\'t find a suitable layer to apply add_skip operation,return origin model')
        return model, False

    # Uniform choice among admissible pairs; numpy's global RNG is used so
    # np.random.seed() still controls which pair is picked.
    from_idx = candidates[np.random.randint(0, len(candidates))]
    from_name = names[from_idx]
    to_name = names[from_idx + 1]
    logger.info('choose {} and {} to add_skip'.format(from_name, to_name))
    return self.skip(model, from_name, to_name, config), True

    # add group operation
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号