GA.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:NetworkCompress 作者: luzai 项目源码 文件源码
def calc_choice_weight(self, evolution_choice_list, model):
        """Assign a relative weight to each requested evolution choice.

        Two quantities feed the weights:

        * the current depth of ``model``: the number of nodes yielded by
          ``nx.topological_sort(model.graph)`` whose ``type`` is
          ``'Conv2D'``, ``'Group'`` or ``'Conv2D_Pooling'``;
        * the limits read from ``model.config``: ``model_max_depth``,
          ``max_pooling_limit`` and ``max_pooling_cnt``.

        Args:
            evolution_choice_list: container of choice names. Only
                ``'deeper_with_pooling'``, ``'deeper'``, ``'wider'``,
                ``'add_skip'`` and ``'add_group'`` receive a weight; any
                other names are ignored.
            model: object exposing ``graph`` (iterated via networkx) and
                ``config`` with the three attributes listed above.

        Returns:
            dict mapping each recognised choice found in
            ``evolution_choice_list`` to its weight:

            * ``'deeper_with_pooling'``:
              ``int(max_depth - max_depth / (2 * pooling_limit) * pooling_cnt)``,
              which decreases linearly in ``max_pooling_cnt`` when the max
              depth and pooling limit are positive;
            * ``'deeper'`` and ``'wider'``: ``max_depth / 2``;
            * ``'add_skip'`` and ``'add_group'``: ``depth / 2 * 2``.

        Raises:
            ZeroDivisionError: presumably, if ``'deeper_with_pooling'`` is
                requested while ``max_pooling_limit`` is 0 (for plain
                int/float config values).

        ``self`` is not used.
        """
        depth_node_types = ('Conv2D', 'Group', 'Conv2D_Pooling')
        depth = sum(1 for node in nx.topological_sort(model.graph)
                    if node.type in depth_node_types)

        max_depth = model.config.model_max_depth
        pooling_limit = model.config.max_pooling_limit
        pooling_cnt = model.config.max_pooling_cnt

        # Formulas are thunks so that each one is evaluated only when its
        # choice was actually requested (the pooling formula divides by
        # 2 * pooling_limit). Table order matches the order in which choices
        # are tested against evolution_choice_list.
        # NOTE(review): with Python 3 true division, `depth / 2 * 2` is just
        # `depth` as a float; under Python 2 integer division it rounds an
        # int depth down to an even number. Kept verbatim; confirm which
        # behaviour was intended (`depth // 2 * 2` for the latter).
        formulas = (
            ('deeper_with_pooling',
             lambda: int(max_depth
                         - max_depth / (2 * pooling_limit) * pooling_cnt)),
            ('deeper', lambda: max_depth / 2),
            ('wider', lambda: max_depth / 2),
            ('add_skip', lambda: depth / 2 * 2),
            ('add_group', lambda: depth / 2 * 2),
        )

        return {choice: formula()
                for choice, formula in formulas
                if choice in evolution_choice_list}
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号