def make_testable(train_model_path):
    """Build a test-ready copy of a training net definition.

    Reads the prototxt at ``train_model_path`` into a ``NetParameter``
    message, then:

    * gives every "BN" layer that has exactly one top two extra tops named
      ``<top>-mean`` and ``<top>-var``;
    * if the second layer is named "data" and carries an ``include`` rule,
      deletes that layer, and then also drops the first ``include`` rule of
      the (new) first layer when it has one.

    Args:
        train_model_path: Path to the training net prototxt (text format).

    Returns:
        The modified ``caffe_pb2.NetParameter`` message.
    """
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        if layer.type == "BN" and len(layer.top) == 1:
            base_top = layer.top[0]
            layer.top.append(base_top + "-mean")
            layer.top.append(base_top + "-var")

    # A net with fewer than two layers has no second layer to inspect;
    # return early instead of raising IndexError on layer[1].
    if len(train_net.layer) < 2:
        return train_net

    # remove the test data layer if present
    second_layer = train_net.layer[1]
    if second_layer.name == "data" and second_layer.include:
        del train_net.layer[1]
        first_layer = train_net.layer[0]
        if first_layer.include:
            # remove the 'include {phase: TRAIN}' layer param
            del first_layer.include[0]
    return train_net
compute_bn_statistics_depth.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录