test_surgeon.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:keras-surgeon 作者: BenWhetton 项目源码 文件源码
def layer_test_helper_flatten_3d(layer, channel_index, data_format):
    """Check channel deletion through ``layer`` and ``Flatten`` into ``Dense``.

    Builds Input -> Conv3D -> ``layer`` -> Flatten -> Dense, deletes the
    requested channels from the Conv3D layer with ``Surgeon`` and asserts
    that the Dense layer of the operated model holds the original Dense
    weights with the rows belonging to the deleted channels removed.

    The layer following the flatten is a Dense layer (not a Conv layer) so
    that a wrong flattened output shape shows up as a weight mismatch.

    Args:
        layer: Layer instance applied between the Conv3D and Flatten layers.
        channel_index: Channel indices to delete from the Conv3D layer; each
            one is reduced modulo the layer's channel count before the
            expected weights are computed.
        data_format: ``'channels_first'`` or ``'channels_last'``.

    Raises:
        ValueError: If ``data_format`` is neither of the two values above
            (when execution gets as far as the index calculation).
        AssertionError: If the operated model's Dense weights differ from
            the expected ones.
    """
    # Positions in model.layers: 0 is the input, 1 the Conv3D, 2 ``layer``,
    # 3 the Flatten and 4 the Dense layer.
    conv_position = 1
    dense_position = 4

    # Assemble the model with a randomly sized 4-entry input shape.
    inputs = Input(shape=list(random.randint(10, 20, size=4)))
    conv_out = Conv3D(3, [3, 3, 2], data_format=data_format)(inputs)
    layer_out = layer(conv_out)
    flattened = Flatten()(layer_out)
    outputs = Dense(5)(flattened)
    model = Model(inputs=inputs, outputs=outputs)

    # Remove the requested channels from the Conv3D layer.
    pruned_layer = model.layers[conv_position]
    surgeon = Surgeon(model)
    surgeon.add_job('delete_channels', pruned_layer, channels=channel_index)
    pruned_model = surgeon.operate()
    actual_weights = pruned_model.layers[dense_position].get_weights()

    # Work out which rows of the Dense kernel correspond to deleted channels.
    # The leading entry of the output shape is dropped (presumably the batch
    # dimension), leaving the number of values fed into the Flatten layer.
    flat_size = np.prod(layer.get_output_shape_at(0)[1:])
    n_channels = getattr(pruned_layer, utils.get_channels_attr(pruned_layer))
    wrapped = [idx % n_channels for idx in channel_index]
    if data_format == 'channels_first':
        # Each channel occupies one contiguous run of the flattened vector.
        run_length = flat_size // n_channels
        rows_to_drop = []
        for channel in wrapped:
            run_start = channel * flat_size // n_channels
            for offset in range(0, run_length):
                rows_to_drop.append(run_start + offset)
    elif data_format == 'channels_last':
        # Channels are interleaved: one entry per channel, every n_channels.
        rows_to_drop = []
        for base in range(0, flat_size, n_channels):
            for channel in wrapped:
                rows_to_drop.append(channel + base)
    else:
        raise ValueError

    # Expected weights: the original Dense weights minus the dropped rows.
    expected_weights = model.layers[dense_position].get_weights()
    expected_weights[0] = np.delete(expected_weights[0], rows_to_drop, axis=0)

    assert weights_equal(expected_weights, actual_weights)
评论列表


问题


面经


文章

微信
公众号

扫码关注公众号