def test_delete_channels_merge_concatenate(channel_index, data_format):
    """Deleting channels upstream of a Concatenate must update the Dense weights.

    Builds a two-input model in which ``conv_1`` and ``conv_2`` (3 filters
    each) are concatenated along the channel axis (``cat_1``), flattened and
    fed into ``dense_1``. The channels in ``channel_index`` are deleted from
    ``conv_1`` on a copy of the model with ``Surgeon``. The new ``dense_1``
    weights are then compared against the original weights with the kernel
    rows that consumed the deleted channels removed.

    A Dense layer (rather than a Conv layer) follows the merge so that the
    test checks the flattened output ordering: each deleted channel must map
    to exactly the right rows of the Dense kernel.

    Args:
        channel_index: indices of ``conv_1`` channels to delete. Negative
            indices are accepted and wrapped modulo the channel count.
        data_format: ``'channels_first'`` or ``'channels_last'``.

    Raises:
        ValueError: if ``data_format`` is not one of the two supported values.
    """
    if data_format == 'channels_first':
        axis = 1
    elif data_format == 'channels_last':
        axis = -1
    else:
        raise ValueError('Unsupported data_format: {!r}'.format(data_format))

    # Build the model: two parallel convolutions merged on the channel axis.
    input_shape = list(random.randint(10, 20, size=3))
    input_1 = Input(shape=input_shape)
    input_2 = Input(shape=input_shape)
    x = Conv2D(3, [3, 3], data_format=data_format, name='conv_1')(input_1)
    y = Conv2D(3, [3, 3], data_format=data_format, name='conv_2')(input_2)
    x = Concatenate(axis=axis, name='cat_1')([x, y])
    x = Flatten()(x)
    main_output = Dense(5, name='dense_1')(x)
    model = Model(inputs=[input_1, input_2], outputs=main_output)
    # Reference weights, captured before any surgery takes place.
    old_w = model.get_layer('dense_1').get_weights()

    # Delete channels from conv_1 on a copy of the model.
    cat_layer = model.get_layer('cat_1')
    del_layer = model.get_layer('conv_1')
    surgeon = Surgeon(model, copy=True)
    surgeon.add_job('delete_channels', del_layer, channels=channel_index)
    new_model = surgeon.operate()
    new_w = new_model.get_layer('dense_1').get_weights()

    # Work out which flattened inputs of dense_1 came from the deleted channels.
    cat_shape = cat_layer.get_output_shape_at(0)
    flat_sz = np.prod(cat_shape[1:])
    # Total channels after the merge. Reading this from the shape avoids
    # hard-coding "two inputs of equal width".
    merged_channels = cat_shape[axis]
    spatial_sz = flat_sz // merged_channels
    channel_count = getattr(del_layer, utils.get_channels_attr(del_layer))
    # Wrap negative indices so they address conv_1's channels directly.
    channel_index = [i % channel_count for i in channel_index]
    # conv_1 is the first Concatenate input, so its channels occupy merged
    # positions 0 .. channel_count - 1. The index maths below assumes Flatten()
    # keeps the tensor's native axis order (no channels_first permutation),
    # which is what the original expectations encoded. TODO confirm against
    # the Keras image_data_format in use.
    if data_format == 'channels_first':
        # (C, H, W): each channel owns a contiguous run of spatial_sz inputs.
        delete_indices = [ch * spatial_sz + i
                          for ch in channel_index
                          for i in range(spatial_sz)]
    else:
        # (H, W, C): channels repeat with a stride of merged_channels.
        delete_indices = [pos * merged_channels + ch
                          for pos in range(spatial_sz)
                          for ch in channel_index]

    correct_w = list(old_w)
    correct_w[0] = np.delete(correct_w[0], delete_indices, axis=0)
    assert weights_equal(correct_w, new_w)
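# Stray web-page navigation labels ("Comment list", "Table of contents")
# captured with this snippet; kept only as a comment so the module parses.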