def bn_absorber_prototxt(model, keep=('bn0_1',)):
    """Load a Caffe ``.prototxt`` and return it with its "BN" layers removed.

    Each removed BN layer is bypassed: the layer immediately after it is
    rewired to read the first top blob of the layer immediately before it
    (see ``_bypass_bn_bottom``). This assumes a strictly sequential
    ``producer -> BN -> consumer`` ordering of layers in the prototxt --
    NOTE(review): confirm for branching topologies. The BN parameters are
    presumably folded into the producer's weights by a separate step.

    A BN layer is left in place when
      * its name is in ``keep``,
      * the layer before it is a ``Deconvolution`` layer, or
      * it is the first layer, so there is no producer to rewire to.
    A BN layer with no layer after it is removed without any rewiring.

    Args:
        model: path of the ``.prototxt`` file to read.
        keep: collection of BN layer names that must never be removed
            (defaults to ``('bn0_1',)``).

    Returns:
        The modified ``caffe_pb2.NetParameter``; nothing is written to disk.

    Raises:
        Exception: if the consumer of a removed BN layer has an unsupported
            number of bottom blobs, or none of its two bottoms contains 'bn'.
    """
    # Load the prototxt file as a protobuf message.
    with open(model) as k:
        str1 = k.read()
    msg1 = caffe_pb2.NetParameter()
    text_format.Merge(str1, msg1)

    layers = msg1.layer
    i = 0
    # Walk by index instead of ``enumerate``: deleting from ``layers`` while
    # enumerating it shifts the next layer into the freed slot, which the
    # iterator would then skip.
    while i < len(layers):
        layer = layers[i]
        # ``i == 0`` is tested before ``layers[i - 1]`` so index -1 (the last
        # layer) is never mistaken for the producer.
        if (layer.type != "BN" or layer.name in keep or i == 0
                or layers[i - 1].type == 'Deconvolution'):
            i += 1
            continue
        producer_top = layers[i - 1].top[0]
        # Delete by index so exactly this layer is removed.
        del layers[i]
        # The consumer has shifted into slot ``i``; ``i`` is deliberately not
        # advanced so the consumer is examined too (it may be another BN).
        if i < len(layers):
            _bypass_bn_bottom(layers[i], producer_top)
    return msg1


def _bypass_bn_bottom(consumer, producer_top):
    """Make ``consumer`` read ``producer_top`` instead of a removed BN's output.

    The BN output is assumed to be the consumer's sole bottom or, for a
    two-bottom consumer, the first bottom whose blob name contains 'bn'.

    Raises:
        Exception: if ``consumer`` had zero or more than two bottoms, or had
            two bottoms and neither name contains 'bn'.
    """
    bottoms = consumer.bottom
    bottoms.append(producer_top)
    if len(bottoms) == 2:
        # Single-input consumer: drop the old (BN) input.
        del bottoms[0]
    elif len(bottoms) == 3:
        # Two-input consumer: drop only the input whose name contains 'bn'.
        if 'bn' in bottoms[0]:
            del bottoms[0]
        elif 'bn' in bottoms[1]:
            del bottoms[1]
        else:
            raise Exception("no bottom blob with name 'bn' present in {} layer".format(consumer))
    else:
        raise Exception("bn absorber does not support more than 2 input blobs for layer {}"
                        .format(consumer))
    # The append above put producer_top last; for an Upsample layer swap the
    # two bottoms so producer_top comes first (presumably Upsample expects its
    # data input first -- confirm against the Upsample layer definition).
    if consumer.type == 'Upsample' and len(bottoms) == 2:
        bottoms[0], bottoms[1] = bottoms[1], bottoms[0]
评论列表
文章目录