def conv_wider(model):
    """Double the filter count of every Conv2D layer and rebuild the model.

    The model is flattened with ``get_model_list`` into a list of mutable
    layer descriptors whose element 0 is the layer type name. For each
    ``'Conv2D'`` descriptor, element 2 (its filter count) is doubled in
    place. Any run of ``'ResidualBlock'`` descriptors that immediately
    follows a widened convolution gets element 1 (its input width) set to
    the new filter count, so the residual path matches the wider input.

    Args:
        model: The model to widen; passed unchanged to ``get_model_list``.

    Returns:
        Whatever ``Build`` returns for the modified layer list.
    """
    model_list = get_model_list(model)
    for idx, layer in enumerate(model_list):
        if layer[0] != 'Conv2D':
            continue
        # Widen in place: ``layer`` is the same object stored in model_list.
        layer[2] *= 2
        # Propagate the new width into the residual blocks directly after
        # this convolution. The length check comes first (short-circuit) so
        # a Conv2D at the end of the list, or a residual run reaching the
        # end, does not raise IndexError.
        next_idx = idx + 1
        while (next_idx < len(model_list)
               and model_list[next_idx][0] == 'ResidualBlock'):
            model_list[next_idx][1] = layer[2]
            next_idx += 1
    return Build(model_list)
# 评论列表 (web-page residue: "comment list")
# 文章目录 (web-page residue: "table of contents")