base_enhancement.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:NetworkCompress 作者: luzai 项目源码 文件源码
def add_skipping(model):
    model_list = get_model_list(model)

    insert_idx = -1
    # TODO: need to get the output shape from the last layer, use it as a parameter
    for idx, layer in enumerate(model_list):
        if layer[0] == 'Conv2D' or layer[0] == 'InceptionBlock' or layer[0] == 'ResidualBlock':
            insert_idx = idx + 1
            if layer[0] == 'Conv2D':
                pre_output_shape = layer[2]
            else:
                pre_output_shape = layer[1]
    if insert_idx != -1:
        model_list.insert(insert_idx, ['ResidualBlock', pre_output_shape])

    new_model = Build(model_list)

    return new_model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号