def _load_src_params_plain(self, pretrained_model):
    """Load every layer's parameter blobs from a serialized Caffe model.

    Args:
        pretrained_model: Path to a binary file containing a serialized
            ``caffe_pb2.NetParameter`` (e.g. a ``.caffemodel`` file).

    Returns:
        dict mapping each layer's original name to a list of numpy arrays,
        one per blob of that layer, in the order the blobs are stored.
        Layers without blobs map to an empty list. If two layers share a
        name, the later one overwrites the earlier entry.
    """

    def _blob_to_array(blob):
        # Convert a single BlobProto into an ndarray with its recorded shape.
        dims = tuple(blob.shape.dim)
        # Older models leave ``shape`` empty and record the shape in the
        # legacy num/channels/height/width fields instead; reshaping a
        # multi-element blob to an empty shape would raise, so fall back.
        if not dims and any(blob.HasField(field)
                            for field in ('num', 'channels', 'height', 'width')):
            dims = (blob.num, blob.channels, blob.height, blob.width)
        return np.array(blob.data).reshape(dims)

    # Parse the binary protobuf; the file is closed as soon as it is read.
    model = caffe_pb2.NetParameter()
    with open(pretrained_model, 'rb') as f:
        model.ParseFromString(f.read())

    # Current-format nets use ``layer``; legacy V1 nets only fill ``layers``.
    layers = model.layer if len(model.layer) else model.layers

    src_params = {}
    for lc in layers:
        src_params[lc.name] = [_blob_to_array(blob) for blob in lc.blobs]
    return src_params
评论列表
文章目录