def layer_test_helper_flatten_3d(layer, channel_index, data_format):
    """Check channel deletion from a Conv3D layer whose output is flattened.

    Builds ``Input -> Conv3D -> layer -> Flatten -> Dense``, deletes
    ``channel_index`` from the Conv3D layer with ``Surgeon``, and asserts
    that the Dense layer's new weights equal its original weights with the
    rows fed by the deleted channels removed.

    The flattened output goes into a Dense layer rather than another Conv
    layer, so the check exercises the mapping from Conv3D channels to
    positions in the flattened vector. The weighted layer whose input rows
    must be trimmed is the Dense layer after the Flatten.

    Args:
        layer: The layer under test, inserted between Conv3D and Flatten.
            Its output shape is read from node 0, so it is assumed not to
            have been called on any other input beforehand.
        channel_index: Channel indices to delete from the Conv3D layer.
            They are passed to the surgeon unchanged. When computing the
            expected weights they are wrapped modulo the channel count, so
            negative indices are supported.
        data_format: ``'channels_first'``, ``'channels_last'`` or ``None``.
            It is passed to Conv3D. With ``None``, Keras uses its configured
            default, and the expected weights follow whatever format the
            Conv3D layer ended up with.

    Raises:
        AssertionError: If the surgically modified weights are incorrect.
        ValueError: If the Conv3D layer reports an unsupported data format.
    """
    # Build the model. The input has four random dims in [10, 20): three
    # spatial dims plus the channel dim.
    main_input = Input(shape=list(random.randint(10, 20, size=4)))
    x = Conv3D(3, [3, 3, 2], data_format=data_format)(main_input)
    x = layer(x)
    x = Flatten()(x)
    main_output = Dense(5)(x)
    model = Model(inputs=main_input, outputs=main_output)

    # model.layers is [Input, Conv3D, layer, Flatten, Dense].
    del_layer_index = 1
    next_layer_index = 4
    del_layer = model.layers[del_layer_index]

    # Delete the channels and fetch the Dense layer's resulting weights.
    surgeon = Surgeon(model)
    surgeon.add_job('delete_channels', del_layer, channels=channel_index)
    new_model = surgeon.operate()
    new_w = new_model.layers[next_layer_index].get_weights()

    # Calculate the Dense layer's expected weights.
    flat_sz = int(np.prod(layer.get_output_shape_at(0)[1:]))
    channel_count = getattr(del_layer, utils.get_channels_attr(del_layer))
    # Assumes `layer` preserves the channel count. Each channel then owns
    # exactly flat_sz // channel_count entries of the flattened vector.
    per_channel = flat_sz // channel_count
    deleted_channels = [i % channel_count for i in channel_index]
    # Use the format Conv3D was actually built with. This resolves
    # data_format=None to the Keras default instead of rejecting it after
    # the surgery has already run.
    effective_format = del_layer.data_format
    if effective_format == 'channels_first':
        # Channel c occupies one contiguous run of the flattened vector:
        # [c * per_channel, (c + 1) * per_channel).
        delete_indices = [c * per_channel + i
                          for c in deleted_channels
                          for i in range(per_channel)]
    elif effective_format == 'channels_last':
        # Channels are interleaved: channel c appears at every
        # channel_count-th position, starting at offset c.
        delete_indices = [c + i
                          for i in range(0, flat_sz, channel_count)
                          for c in deleted_channels]
    else:
        raise ValueError(
            'Unsupported data_format {!r}; expected "channels_first" or '
            '"channels_last".'.format(effective_format))
    correct_w = model.layers[next_layer_index].get_weights()
    # Remove the kernel rows that received the deleted channels' outputs.
    correct_w[0] = np.delete(correct_w[0], delete_indices, axis=0)
    assert weights_equal(correct_w, new_w)
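# Web-page residue (originally the "Comment list" and "Table of contents"
# page headings); not part of the code.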