def layer_test_helper_1d_global(layer, channel_index):
    """Check channel deletion in a Conv1D layer that feeds a 1D global layer.

    Builds the model ``Input -> Conv1D(3, 3) -> layer -> Dense(5)``. It then
    deletes the channels listed in ``channel_index`` from the Conv1D layer
    using ``operations.delete_channels``. Finally it asserts that the Dense
    layer's weights in the new model equal the original Dense weights with
    the rows for the deleted channels removed from the first weight array.

    ``layer`` is followed by a Dense layer rather than another Conv layer, so
    the check also depends on ``layer`` producing the expected output shape.
    The layer whose channels are deleted is the Conv1D layer just before
    ``layer``.

    Args:
        layer: Layer instance applied to the Conv1D output. Presumably a
            global 1D pooling-style layer; TODO confirm against callers.
        channel_index: Channel indices of the Conv1D layer to delete. When
            computing the expected weights, they are wrapped modulo the
            channel count, so negative indices count from the end.
    """
    # Positions in model.layers of the layer whose channels are deleted
    # (the Conv1D) and of the layer whose weights are checked (the Dense).
    conv_position = 1
    dense_position = 3

    # Two random per-sample dimensions, each drawn from [10, 20). The `size=`
    # keyword implies a numpy.random-style `random` module.
    input_dims = list(random.randint(10, 20, size=2))
    inputs = Input(shape=input_dims)
    conv_out = Conv1D(3, 3)(inputs)
    layer_out = layer(conv_out)
    outputs = Dense(5)(layer_out)
    model = Model(inputs=inputs, outputs=outputs)

    # Remove the requested channels and read the resulting Dense weights.
    conv_layer = model.layers[conv_position]
    pruned_model = operations.delete_channels(model, conv_layer, channel_index)
    actual_weights = pruned_model.layers[dense_position].get_weights()

    # Expected Dense weights: the original ones with the rows of the first
    # weight array (axis 0) that correspond to deleted channels removed.
    n_channels = getattr(conv_layer, utils.get_channels_attr(conv_layer))
    wrapped_indices = [idx % n_channels for idx in channel_index]
    expected_weights = model.layers[dense_position].get_weights()
    expected_weights[0] = np.delete(expected_weights[0], wrapped_indices, axis=0)

    assert weights_equal(expected_weights, actual_weights)
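# 评论列表 ("Comment list") -- web-page residue from the source this was copied from
# 文章目录 ("Table of contents") -- web-page residue from the source this was copied from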