def predict(the_net, image):
    """Run one image file through ``the_net`` and return its ``'feature'`` blob.

    Preprocessing, in order:
      1. read the file with ``cv2.imread``;
      2. resize to ``SIZE`` x ``SIZE`` (module-level constant);
      3. crop the 128x128 window whose top-left corner is pixel (11, 11)
         (assumes ``SIZE >= 139`` -- TODO confirm against the SIZE definition);
      4. subtract the module-level ``mean``;
      5. transpose from HxWxC to CxHxW and convert to float32.

    Args:
        the_net: network exposing ``blobs``, ``reshape()`` and ``forward()``,
            with an input blob named ``'data'`` and an output blob named
            ``'feature'`` (presumably a ``caffe.Net`` -- verify against caller).
        image: path to the image file.

    Returns:
        A copy of ``the_net.blobs['feature'].data`` after the forward pass, or
        ``None`` if the file could not be decoded or preprocessed.

    Raises:
        FileNotFoundError: if ``image`` does not exist. This is a subclass of
            ``Exception``, so callers catching ``Exception`` are unaffected.
    """
    if not os.path.exists(image):
        raise FileNotFoundError("Image path not exist")

    try:
        img = cv2.imread(image)
        if img is None:
            # cv2.imread reports unreadable/unsupported files by returning
            # None instead of raising.
            return None
        img = cv2.resize(img, (SIZE, SIZE))
        img = img[11:11 + 128, 11:11 + 128]
        # Subtract in float64: doing it on the raw uint8 pixels would wrap
        # around (e.g. 10 - 104 -> 162) whenever ``mean`` has an integer
        # dtype. For float means this matches the previous uint8-minus-float
        # result exactly.
        img = img.astype(np.float64) - mean
        img = img.transpose((2, 0, 1))
        img = np.require(img, dtype=np.float32)
    except Exception:
        # Deliberately best-effort (the original author commented out a raise
        # here): damaged or illegal files yield None rather than an error.
        return None

    # Resize the input blob to a batch of one image, then propagate the new
    # shape through the rest of the net before filling it.
    the_net.blobs['data'].reshape(1, *img.shape)
    the_net.reshape()
    the_net.blobs['data'].data[...] = img
    the_net.forward()
    # Return a copy so the result is independent of the net's blob buffer,
    # which later forward passes presumably overwrite.
    return copy.deepcopy(the_net.blobs['feature'].data)
extract_res10.py 文件源码
python
阅读 31
收藏 0
点赞 0
评论 0
评论列表
文章目录