BN-absorber-enet.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:ENet 作者: TimoSaemann 项目源码 文件源码
def add_bias_to_conv(model, weights, out_dir):
    """Build a copy of a Caffe net in which every convolution has a bias term.

    Convolution layers declared with ``bias_term: false`` are switched to
    ``bias_term: true`` with a constant bias filler of 0.0. The modified
    definition is written to ``<out_dir>/model_temp.prototxt``.

    A net is then built from that definition. The learned blobs of the
    original net (``model`` + ``weights``) are copied into it, and the result
    is saved to ``<out_dir>/weights_temp.caffemodel``. The newly added bias
    blobs are not present in the source net, so they keep the value of their
    constant filler.

    Args:
        model: Path to the source network definition (.prototxt).
        weights: Path to the trained weights for ``model`` (.caffemodel).
        out_dir: Directory that receives the two temporary files. It is not
            created here, so it must already exist.

    Returns:
        Tuple ``(model_temp, weights_temp)`` holding the paths of the written
        prototxt and caffemodel.
    """
    # Load the prototxt file as a NetParameter protobuf message.
    with open(model) as proto_file:
        prototxt = proto_file.read()
    net_param = caffe_pb2.NetParameter()
    text_format.Merge(prototxt, net_param)

    for layer in net_param.layer:
        if layer.type != "Convolution":
            continue
        conv_param = layer.convolution_param
        if not conv_param.bias_term:
            conv_param.bias_term = True
            conv_param.bias_filler.type = 'constant'
            conv_param.bias_filler.value = 0.0  # already the default; set explicitly

    model_temp = os.path.join(out_dir, "model_temp.prototxt")
    # Call form prints the same text on Python 2 and also parses on Python 3.
    print("Saving temp model...")
    with open(model_temp, 'w') as out_file:
        out_file.write(text_format.MessageToString(net_param))

    net_src = caffe.Net(model, weights, caffe.TEST)
    net_des = caffe.Net(model_temp, caffe.TEST)

    # Copy every learned blob into the same slot of the biased net. A layer
    # that just gained a bias has one more blob in net_des than in net_src,
    # and that extra blob is left untouched.
    for layer_name, src_blobs in net_src.params.items():
        des_blobs = net_des.params[layer_name]
        for i, src_blob in enumerate(src_blobs):
            des_blobs[i].data[...] = src_blob.data

    # Save the weights, now including the (zero) biases.
    weights_temp = os.path.join(out_dir, "weights_temp.caffemodel")
    print("Saving temp weights...")
    net_des.save(weights_temp)

    return model_temp, weights_temp
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号