base_enhancement.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:NetworkCompress 作者: luzai 项目源码 文件源码
def conv_wider(model):
    model_list = get_model_list(model)

    for idx, layer in enumerate(model_list):
        if layer[0] == 'Conv2D':
            wider_layer = layer
            insert_idx = idx + 1

    # wider operation: filters * 2
    wider_layer[2] *= 2

    # if next layer is residual layer, we need to change residual layer's input shape
    while (model_list[insert_idx][0] == 'ResidualBlock'):
        model_list[insert_idx][1] = wider_layer[2]
        insert_idx = insert_idx + 1

    new_model = Build(model_list)

    return new_model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号