def predict(self, img, flip_evaluation):
    """Predict a segmentation map for a single image.

    Arguments:
        img: image array of shape rows x cols x 3. If its rows x cols differ
            from ``self.input_shape`` it is resized to fit the network first.
        flip_evaluation: if true, also run the model on the horizontally
            flipped input and average the two predictions.

    Returns:
        ``self.model.predict(...)[0]`` (optionally flip-averaged), resampled
        with bilinear interpolation whenever its first two axes do not already
        match ``img``'s original rows x cols.
    """
    h_ori, w_ori = img.shape[:2]
    # Normalise so a list-valued input_shape compares equal to a shape tuple.
    target_shape = tuple(self.input_shape)
    if img.shape[0:2] != target_shape:
        print("Input %s not fitting for network size %s, resizing. You may want to try sliding prediction for better results." % (img.shape[0:2], self.input_shape))
        imresize = getattr(misc, "imresize", None)
        if imresize is not None:
            img = imresize(img, target_shape)
        else:
            # scipy.misc.imresize was removed in SciPy 1.3: fall back to a
            # bilinear zoom over rows/cols, leaving any channel axis untouched.
            zoom = (1. * target_shape[0] / h_ori, 1. * target_shape[1] / w_ori)
            zoom += (1.,) * (img.ndim - 2)
            img = ndimage.zoom(img, zoom, order=1, prefilter=False)
    input_data = self.preprocess_image(img)
    regular_prediction = self.model.predict(input_data)[0]
    if flip_evaluation:
        print("Predict flipped")
        # Flip the network input along axis 2 (width, assuming a batched
        # NHWC input — TODO confirm), then un-flip the per-image output
        # with fliplr so both predictions are spatially aligned.
        flipped_prediction = np.fliplr(self.model.predict(np.flip(input_data, axis=2))[0])
        prediction = (regular_prediction + flipped_prediction) / 2.0
    else:
        prediction = regular_prediction
    # Map the prediction back to the caller's original resolution. The old
    # check compared a 1-tuple slice of the already-resized image, so it was
    # always true and could never see the original size.
    h, w = prediction.shape[:2]
    if (h, w) != (h_ori, w_ori):
        prediction = ndimage.zoom(prediction, (1. * h_ori / h, 1. * w_ori / w, 1.),
                                  order=1, prefilter=False)
    return prediction
评论列表
文章目录