def add_bias_to_conv(model, weights, out_dir):
    """Give bias-less Convolution layers a bias term while keeping trained weights.

    The network definition in ``model`` is parsed and edited. Each
    ``Convolution`` layer whose ``convolution_param.bias_term`` is false is
    switched to ``bias_term: true`` with a constant 0.0 bias filler. The
    edited definition is written to ``<out_dir>/model_temp.prototxt``.

    Every parameter blob of the original net (``model`` + ``weights``) is then
    copied into the edited net. Any extra blobs that exist only in the edited
    net (the newly added biases) keep their zero filler value. The result is
    saved to ``<out_dir>/weights_temp.caffemodel``.

    Args:
        model: Path to the source network definition (prototxt text format).
        weights: Path to the trained weights file matching ``model``.
        out_dir: Directory for the two output files; created if missing.

    Returns:
        Tuple ``(model_temp, weights_temp)`` holding the paths of the written
        prototxt and weights files.
    """
    # Parse the prototxt into a NetParameter message so it can be edited.
    with open(model) as proto_file:
        proto_text = proto_file.read()
    net_param = caffe_pb2.NetParameter()
    text_format.Merge(proto_text, net_param)

    for layer in net_param.layer:
        if layer.type == "Convolution" and not layer.convolution_param.bias_term:
            conv_param = layer.convolution_param
            conv_param.bias_term = True
            # A constant zero bias leaves the layer's output numerically unchanged.
            conv_param.bias_filler.type = 'constant'
            conv_param.bias_filler.value = 0.0

    # Writing the outputs below would fail on a missing directory.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    model_temp = os.path.join(out_dir, "model_temp.prototxt")
    print("Saving temp model...")
    with open(model_temp, 'w') as out_file:
        out_file.write(text_format.MessageToString(net_param))

    net_src = caffe.Net(model, weights, caffe.TEST)
    net_des = caffe.Net(model_temp, caffe.TEST)
    # Copy only the blobs the source layer actually has. Blobs present only in
    # the destination (the newly enabled biases) keep their zero filler value.
    for layer_name, src_blobs in net_src.params.items():
        dst_blobs = net_des.params[layer_name]
        for i, src_blob in enumerate(src_blobs):
            dst_blobs[i].data[:] = src_blob.data

    # Save the weights, now including the (zero) biases.
    weights_temp = os.path.join(out_dir, "weights_temp.caffemodel")
    print("Saving temp weights...")
    net_des.save(weights_temp)
    return model_temp, weights_temp
评论列表
文章目录