def get_classifier(model_name, crop_size):
    """Build a ``caffe.Classifier`` from the files in a model directory.

    The directory ``model_name`` must contain:

    * ``deploy.prototxt``     -- passed to Caffe as the model definition,
    * ``snapshot.caffemodel`` -- passed to Caffe as the pretrained weights,
    * ``mean.binaryproto``    -- a serialized ``BlobProto`` holding the mean
      image given to the classifier as ``mean``.

    Args:
        model_name: Path of the model directory (with or without a trailing
            separator).
        crop_size: Side length used both for ``image_dims`` and as the
            height/width the stored mean is reshaped to.

    Returns:
        A ``caffe.Classifier`` built with ``image_dims=(crop_size, crop_size)``,
        the loaded mean and ``raw_scale=255``.

    Raises:
        IOError / OSError: if ``mean.binaryproto`` cannot be opened.
        ValueError: if the stored mean's size is not a whole multiple of
            ``crop_size * crop_size`` (it cannot be reshaped to
            ``(channels, crop_size, crop_size)``).
    """
    # Local import so this block is self-contained; the file's import
    # section is not visible from here.
    import os

    # os.path.join avoids "dir//file" when model_name already ends with a
    # separator, and avoids pointing at the filesystem root when it is empty.
    model_file = os.path.join(model_name, 'deploy.prototxt')
    pretrained = os.path.join(model_name, 'snapshot.caffemodel')
    mean_file = os.path.join(model_name, 'mean.binaryproto')

    # Deserialize the mean image; the ``with`` block guarantees the file
    # handle is closed even if parsing fails.
    blob = caffe.proto.caffe_pb2.BlobProto()
    with open(mean_file, 'rb') as handle:
        blob.ParseFromString(handle.read())

    # Infer the channel count (-1) instead of hard-coding 1, so that
    # multi-channel means also load. A single-channel crop_size x crop_size
    # mean still becomes (1, crop_size, crop_size).
    # NOTE(review): Caffe's Transformer.set_mean is expected to reject a mean
    # whose shape does not match the network input -- confirm for your Caffe
    # version.
    mean_arr = np.array(caffe.io.blobproto_to_array(blob)).reshape(
        -1, crop_size, crop_size)
    # Single-argument print() gives identical output on Python 2 and 3.
    print(mean_arr.shape)

    # raw_scale=255 presumably rescales inputs loaded in [0, 1] (e.g. by
    # caffe.io.load_image) back to [0, 255] -- confirm against callers.
    return caffe.Classifier(model_file, pretrained,
                            image_dims=(crop_size, crop_size),
                            mean=mean_arr, raw_scale=255)
评论列表
文章目录