BN-absorber.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:DepthSegnet 作者: hari-sikchi 项目源码 文件源码
def bn_absorber_weights(model, weights):
    """Fold each BN layer into the Convolution layer directly before it.

    For every ``BN`` layer whose immediate predecessor in the prototxt is a
    ``Convolution`` layer, the conv parameters are rewritten in place as::

        weight_new = weight * gamma_new      (per output feature map)
        bias_new   = bias * gamma_new + beta_new

    ``gamma_new`` / ``beta_new`` are the BN layer's first and second parameter
    blobs. Per the original notes, these were already computed by the
    ``compute_bn_statistics.py`` script. Afterwards the BN blobs are zeroed.
    For background see https://github.com/alexgkendall/caffe-segnet/issues/109.

    Args:
        model: Path to the network prototxt.
        weights: Path to the trained weights loaded into the net.

    Returns:
        The ``caffe.Net`` (TEST phase) holding the merged parameters.

    Raises:
        IndexError: If a BN layer holds fewer scale/shift values than the
            preceding conv layer has output feature maps. Presumably also
            raised (via ``conv_params[1]``) when that conv layer has no bias
            blob — TODO confirm.

    NOTE(review): zeroed BN layers are presumably not executed at deploy
    time (e.g. a prototxt without them is used) — confirm with callers.
    """
    # Parse the prototxt so layer types and their order are available.
    with open(model) as f:
        prototxt_text = f.read()
    msg = caffe_pb2.NetParameter()
    text_format.Merge(prototxt_text, msg)

    net = caffe.Net(model, weights, caffe.TEST)

    for i, layer in enumerate(msg.layer):
        # Merging is only possible when a Convolution immediately precedes the
        # BN layer. i == 0 is excluded explicitly: msg.layer[i - 1] would
        # otherwise wrap around to the *last* layer of the network.
        if i == 0 or layer.type != 'BN':
            continue
        prev_layer = msg.layer[i - 1]
        if prev_layer.type != 'Convolution':
            continue

        bn_layer = layer.name
        conv_layer = prev_layer.name
        conv_params = net.params[conv_layer]
        bn_params = net.params[bn_layer]

        kernel_blob = conv_params[0].data
        num_maps = kernel_blob.shape[0]

        # float64 working copies of the original conv parameters.
        weight = copy_double(kernel_blob)
        bias = copy_double(conv_params[1].data)

        # One scale/shift value per output feature map, read in flat order
        # whatever the blob shape. np.array copies, so these values are not
        # aliased to the BN blobs that are zeroed below.
        new_gamma = np.array(bn_params[0].data, dtype=np.float64).reshape(-1)
        new_beta = np.array(bn_params[1].data, dtype=np.float64).reshape(-1)
        if new_gamma.size < num_maps or new_beta.size < num_maps:
            raise IndexError(
                'BN layer %r has fewer scale/shift values than conv layer %r '
                'has output feature maps (%d)' % (bn_layer, conv_layer, num_maps))
        new_gamma = new_gamma[:num_maps]
        new_beta = new_beta[:num_maps]

        # Broadcast gamma across every axis after the output-map axis. This is
        # the same per-map scaling for any kernel rank (2-D, 3-D, ...).
        gamma_shape = (num_maps,) + (1,) * (kernel_blob.ndim - 1)
        conv_params[0].data[...] = weight * new_gamma.reshape(gamma_shape)
        conv_params[1].data[...] = bias * new_gamma + new_beta

        # The BN parameters are now folded into the conv layer; clear them.
        bn_params[0].data[:] = 0
        bn_params[1].data[:] = 0

    return net
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号