test_surgeon.py source code

python

Project: keras-surgeon    Author: BenWhetton
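```python
import numpy as np
from keras.layers import Input, GRU, Dense
from keras.models import Model
from kerassurgeon import operations, utils

# weights_equal(w1, w2) is assumed to be a helper from the same test module (not shown).


def recursive_test_helper(layer, channel_index):
    """Delete channels from `layer` and check the downstream GRU's weights."""
    # Build a small model: Input -> layer under test -> GRU -> Dense.
    main_input = Input(shape=[32, 10])
    x = layer(main_input)
    x = GRU(4, return_sequences=False)(x)
    main_output = Dense(5)(x)
    model = Model(inputs=main_input, outputs=main_output)

    # Delete channels from the layer under test (model.layers[0] is the InputLayer).
    del_layer_index = 1
    next_layer_index = 2
    del_layer = model.layers[del_layer_index]
    new_model = operations.delete_channels(model, del_layer, channel_index)
    new_w = new_model.layers[next_layer_index].get_weights()

    # Calculate the next layer's correct weights: map negative indices into
    # range, then drop the matching input rows of the GRU kernel.
    channel_count = getattr(del_layer, utils.get_channels_attr(del_layer))
    channel_index = [i % channel_count for i in channel_index]
    correct_w = model.layers[next_layer_index].get_weights()
    correct_w[0] = np.delete(correct_w[0], channel_index, axis=0)

    assert weights_equal(correct_w, new_w)
```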
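The expected-weight computation at the end of the helper can be checked without Keras. The NumPy-only sketch below assumes a GRU weight list of the form [kernel, recurrent_kernel, bias] with the kernel shaped (input_dim, 3 * units). `expected_gru_weights` is a hypothetical name, and this `weights_equal` is only a guess at the helper the test relies on, not its actual definition.

```python
import numpy as np


def weights_equal(w1, w2):
    """Hypothetical helper: True if two lists of weight arrays match exactly."""
    if len(w1) != len(w2):
        return False
    return all(np.array_equal(a, b) for a, b in zip(w1, w2))


def expected_gru_weights(gru_weights, channel_index, channel_count):
    """Drop the kernel rows fed by deleted upstream channels.

    Assumed layout: [kernel (input_dim, 3 * units), recurrent_kernel, bias].
    """
    # Map negative indices (e.g. -1) onto 0..channel_count-1, as the test does.
    channel_index = [i % channel_count for i in channel_index]
    correct_w = [w.copy() for w in gru_weights]
    # Only the input kernel depends on the upstream channel count.
    correct_w[0] = np.delete(correct_w[0], channel_index, axis=0)
    return correct_w


# Toy weights for a GRU with 10 input channels and 4 units.
# The bias shape is illustrative; real Keras layouts vary (e.g. with reset_after).
rng = np.random.default_rng(0)
units, input_dim = 4, 10
gru_w = [
    rng.standard_normal((input_dim, 3 * units)),  # kernel
    rng.standard_normal((units, 3 * units)),      # recurrent_kernel
    rng.standard_normal(3 * units),               # bias
]

pruned = expected_gru_weights(gru_w, [0, 4, -1], channel_count=input_dim)
print(pruned[0].shape)                       # (7, 12)
print(weights_equal(pruned[1:], gru_w[1:]))  # True: only the kernel changes
```

Because `%` maps -1 to `channel_count - 1`, the sketch deletes kernel rows 0, 4 and 9, leaving a (7, 12) kernel while the recurrent kernel and bias are untouched.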