base_enhancement.py source code

python

Project: NetworkCompress  Author: luzai
import json


def get_model_list(model):
    """Summarize a Keras model's architecture as a list of [layer_type, details...] entries."""
    model_list = []
    # Serialize the architecture (no weights) and parse it back into a dict.
    model_dict = json.loads(model.to_json())

    # Side effect: also save the architecture description to model_frame.json.
    with open('model_frame.json', 'w') as outfile:
        json.dump(model_dict, outfile)

    model_layer = model_dict['config']['layers']

    for layer in model_layer:
        class_name = layer['class_name']
        layer_name = layer['config']['name']
        layer_output_shape = model.get_layer(layer_name).output_shape
        if class_name == 'InputLayer':
            # Input shape without the leading batch dimension.
            model_list.append([class_name, layer['config']['batch_input_shape'][1:]])
        elif class_name == 'Conv2D' and layer_name.startswith('conv'):
            # Only Conv2D layers named 'conv*': record kernel size and filter count.
            model_list.append([class_name, layer['config']['kernel_size'], layer['config']['filters']])
        elif class_name == 'Add' and layer_name.startswith('res'):
            # Residual merge: record its channel count (index 3 assumes channels-last).
            model_list.append(['ResidualBlock', layer_output_shape[3]])
        elif class_name == 'Concatenate':
            # Concatenation is treated as an Inception-style block.
            model_list.append(['InceptionBlock', layer_output_shape[3]])
        elif class_name in ('GlobalMaxPooling2D', 'Activation'):
            # These layers are recorded by type only.
            model_list.append([class_name])

    return model_list
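To see what kind of list the function produces without building a model, the same mapping rules can be run on a hand-written list shaped like the `config['layers']` entries it walks. The following is a minimal sketch under stated assumptions. `summarize_layers`, the sample layer names, and the shapes are all hypothetical. The output shapes are passed in as a dict instead of being read through `model.get_layer(...).output_shape`.

```python
from typing import Any, Dict, List, Sequence, Tuple


# Hypothetical helper: applies the same classification rules as get_model_list
# to already-parsed layer configs plus a name -> output-shape mapping.
def summarize_layers(layers: Sequence[Dict[str, Any]],
                     output_shapes: Dict[str, Tuple]) -> List[list]:
    summary = []
    for layer in layers:
        cls = layer['class_name']
        cfg = layer['config']
        name = cfg['name']
        if cls == 'InputLayer':
            # Drop the leading batch dimension (None).
            summary.append([cls, cfg['batch_input_shape'][1:]])
        elif cls == 'Conv2D' and name.startswith('conv'):
            summary.append([cls, cfg['kernel_size'], cfg['filters']])
        elif cls == 'Add' and name.startswith('res'):
            # Index 3 is the channel axis of a channels-last 4D output.
            summary.append(['ResidualBlock', output_shapes[name][3]])
        elif cls == 'Concatenate':
            summary.append(['InceptionBlock', output_shapes[name][3]])
        elif cls in ('GlobalMaxPooling2D', 'Activation'):
            summary.append([cls])
    return summary


# Hypothetical layer configs; JSON round-tripping turns tuples into lists.
layers = [
    {'class_name': 'InputLayer', 'config': {'name': 'input_1', 'batch_input_shape': [None, 32, 32, 3]}},
    {'class_name': 'Conv2D', 'config': {'name': 'conv1', 'kernel_size': [3, 3], 'filters': 64}},
    {'class_name': 'Add', 'config': {'name': 'res1'}},
    {'class_name': 'GlobalMaxPooling2D', 'config': {'name': 'gmp'}},
    {'class_name': 'Activation', 'config': {'name': 'softmax'}},
]
shapes = {'res1': (None, 32, 32, 64)}

print(summarize_layers(layers, shapes))
# [['InputLayer', [32, 32, 3]], ['Conv2D', [3, 3], 64], ['ResidualBlock', 64], ['GlobalMaxPooling2D'], ['Activation']]
```

The name-prefix checks act as filters. A `Conv2D` layer whose name does not start with `conv`, such as a shortcut projection, is silently left out of the summary. The same applies to an `Add` layer whose name does not start with `res`.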