test_surgeon.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:keras-surgeon 作者: BenWhetton 项目源码 文件源码
def test_delete_channels_merge_concatenate(channel_index, data_format):
    """Check channel deletion through a Concatenate layer into a Dense layer.

    Two Conv2D branches are concatenated along the channel axis, flattened
    and fed into a Dense layer. Channels are deleted from the first branch
    ('conv_1') and the Dense kernel of the operated model is compared with
    the original kernel after removing the rows that belong to the deleted
    channels. Feeding a Dense layer (rather than another Conv layer) makes
    the check sensitive to the flattened output shape.

    Args:
        channel_index: Iterable of channel indices to delete from 'conv_1'.
            Negative indices are accepted; they are normalised modulo the
            layer's channel count before the expected rows are computed.
        data_format: Either 'channels_first' or 'channels_last'.
            (Presumably supplied by pytest fixtures/parametrisation defined
            elsewhere in the test module -- not visible here.)

    Raises:
        ValueError: If `data_format` is not one of the two supported values.
    """
    if data_format == 'channels_first':
        axis = 1
    elif data_format == 'channels_last':
        axis = -1
    else:
        raise ValueError(
            'Unsupported data_format: {!r}'.format(data_format))

    # Create model: two identical conv branches merged on the channel axis.
    input_shape = list(random.randint(10, 20, size=3))
    input_1 = Input(shape=input_shape)
    input_2 = Input(shape=input_shape)
    out_1 = Conv2D(3, [3, 3], data_format=data_format, name='conv_1')(input_1)
    out_2 = Conv2D(3, [3, 3], data_format=data_format, name='conv_2')(input_2)
    merged = Concatenate(axis=axis, name='cat_1')([out_1, out_2])
    flattened = Flatten()(merged)
    main_output = Dense(5, name='dense_1')(flattened)
    model = Model(inputs=[input_1, input_2], outputs=main_output)

    # Delete channels from the first branch only.
    layer = model.get_layer('cat_1')
    del_layer = model.get_layer('conv_1')
    surgeon = Surgeon(model, copy=True)
    surgeon.add_job('delete_channels', del_layer, channels=channel_index)
    new_model = surgeon.operate()
    new_w = new_model.get_layer('dense_1').get_weights()

    # Calculate next layer's correct weights.
    # flat_sz is the number of Dense input rows, i.e. the flattened size of
    # the concatenated tensor (batch dimension excluded).
    flat_sz = np.prod(layer.get_output_shape_at(0)[1:])
    channel_count = getattr(del_layer, utils.get_channels_attr(del_layer))
    channel_index = [i % channel_count for i in channel_index]
    # Both branches contribute `channel_count` channels, so the concatenated
    # tensor holds 2 * channel_count channels in total.
    total_channels = channel_count * 2
    if data_format == 'channels_first':
        # Flattening keeps each channel's spatial positions contiguous:
        # every deleted channel removes one whole run of rows.
        per_channel = flat_sz // 2 // channel_count
        delete_indices = [channel * per_channel + offset
                          for channel in channel_index
                          for offset in range(0, per_channel)]
    else:
        # data_format was validated above, so this is 'channels_last':
        # channels are interleaved, one group of `total_channels` rows per
        # spatial position.
        delete_indices = [channel + start
                          for start in range(0, flat_sz, total_channels)
                          for channel in channel_index]
    correct_w = model.get_layer('dense_1').get_weights()
    correct_w[0] = np.delete(correct_w[0], delete_indices, axis=0)

    assert weights_equal(correct_w, new_w)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号