def recursive_test_helper(layer, channel_index):
    """Verify that deleting channels from ``layer`` trims the downstream GRU.

    Builds the model ``Input(32, 10) -> layer -> GRU(4) -> Dense(5)`` and then
    calls ``operations.delete_channels`` on ``layer``. It asserts that the GRU
    in the rebuilt model keeps the original GRU's weights, except that the
    rows of its first weight array matching the deleted channels are removed
    (``axis=0``).

    Args:
        layer: Layer instance applied to the ``(32, 10)`` input. Its output is
            fed straight into a GRU, so presumably it must emit a sequence
            (e.g. a recurrent layer with ``return_sequences=True``) -- TODO
            confirm against callers.
        channel_index: Indices of the channels to delete from ``layer``. They
            are passed to ``delete_channels`` unchanged. When building the
            expected weights they are reduced modulo the layer's channel
            count, so negative indices wrap around.
    """
    inputs = Input(shape=[32, 10])
    sequence = layer(inputs)
    hidden = GRU(4, return_sequences=False)(sequence)
    model = Model(inputs=inputs, outputs=Dense(5)(hidden))

    # Position 1 in model.layers is the layer under test; position 2 is the
    # GRU that consumes its output (the Input layer comes first).
    target_pos, downstream_pos = 1, 2
    target = model.layers[target_pos]
    pruned_model = operations.delete_channels(model, target, channel_index)
    actual = pruned_model.layers[downstream_pos].get_weights()

    # Expected result: the original GRU weights with the deleted input
    # channels dropped from the first weight array.
    n_channels = getattr(target, utils.get_channels_attr(target))
    wrapped = [idx % n_channels for idx in channel_index]
    expected = model.layers[downstream_pos].get_weights()
    expected[0] = np.delete(expected[0], wrapped, axis=0)

    assert weights_equal(expected, actual)
评论列表
文章目录