Net2Net.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:NetworkCompress 作者: luzai 项目源码 文件源码
def wider(self, model, config, max_cur_width=1024):
        """Randomly pick a conv-like layer of ``model`` and widen it.

        Candidate layers are the nodes of ``model.graph`` whose type is
        ``'Conv2D'``, ``'Conv2D_Pooling'`` or ``'Group'``, taken in topological
        order; the last such layer is never chosen. A candidate is eligible
        only if at least one of its successor nodes is a ``'Conv2D'`` or
        ``'Conv2D_Pooling'`` layer. One eligible layer is drawn uniformly at
        random and its ``filters`` count is moved to a random value between
        its current width and ``max_cur_width`` (rounded down to a multiple of
        ``group_num`` for ``'Group'`` layers).

        Args:
            model: Model whose ``graph`` is topologically sorted with
                ``nx.topological_sort`` and queried via ``get_nodes``.
            config: Forwarded unchanged to the widening helper.
            max_cur_width: Upper bound for the sampled new width. Defaults to
                1024, the value previously hard-coded here.

        Returns:
            A ``(model, widened)`` tuple. ``widened`` is ``False`` and the
            input ``model`` is returned as-is when no eligible layer exists or
            the sampled width does not exceed the current width. Otherwise
            ``widened`` is ``True`` and the first element is the result of
            ``self.wider_group_conv2d`` (for ``'Group'`` layers) or
            ``self.wider_conv2d``.
        """
        conv_like_types = ('Conv2D', 'Conv2D_Pooling', 'Group')  # 'Group' layers may be widened too
        next_conv_types = ('Conv2D', 'Conv2D_Pooling')
        names = [node.name
                 for node in nx.topological_sort(model.graph)
                 if node.type in conv_like_types]

        # Drop the last conv-like layer, then keep only layers that feed a conv
        # layer. Filtering up front (rather than re-sampling up to 100 times)
        # means an eligible layer is never missed by chance, and it avoids
        # np.random.randint raising ValueError when fewer than two candidates exist.
        eligible = [
            name for name in names[:-1]
            if any(node.type in next_conv_types
                   for node in model.graph.get_nodes(name, next_layer=True, last_layer=False))
        ]
        if not eligible:
            logger.info('can\'t find a suitable layer to apply wider operation,return origin model')
            return model, False

        choice = eligible[np.random.randint(0, len(eligible))]
        cur_node = model.graph.get_nodes(choice)[0]
        cur_width = cur_node.config['filters']

        width_ratio = np.random.rand()
        new_width = int(cur_width + width_ratio * (max_cur_width - cur_width))
        if cur_node.type == 'Group':
            # Round down so that new_width % group_num == 0.
            group_num = cur_node.config['group_num']
            new_width = new_width // group_num * group_num

        if new_width <= cur_width:
            logger.info('{} layer\'s width up to limit!'.format(choice))
            return model, False
        logger.info('choose {} to wider'.format(choice))
        if cur_node.type == 'Group':
            return self.wider_group_conv2d(model, layer_name=choice, new_width=new_width, config=config), True
        return self.wider_conv2d(model, layer_name=choice, new_width=new_width, config=config), True
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号