def preprocess(self, inputs):
    """Run the model's preprocessing function over one image or a batch.

    The per-image function returned by ``self.get_preprocess_fn()`` is applied
    to every image of the batch inside a fresh TensorFlow graph. The stacked
    results are then evaluated in a one-off ``tf.Session``.

    Args:
      inputs: raw image data, either a numpy array or anything ``np.asarray``
        accepts. It must have 3 dimensions (a single image) or 4 dimensions
        (a batch whose leading axis indexes images). Values are converted to
        ``tf.uint8``. Presumably each image is (height, width, channels);
        TODO confirm against ``get_preprocess_fn``.

    Returns:
      The preprocessed batch as returned by ``sess.run``. A 3-D input gets a
      leading batch axis of size 1, and that axis is not squeezed away.

    Raises:
      ValueError: if ``inputs`` is neither 3- nor 4-dimensional.
    """
    preprocess_fn = self.get_preprocess_fn()
    # Accept lists and other array-likes as well as ndarrays.
    # np.asarray does not copy an existing ndarray.
    inputs = np.asarray(inputs)
    # Use an explicit check rather than `assert`, which `python -O` strips.
    if inputs.ndim not in (3, 4):
        raise ValueError(
            "invalid image format for preprocessing: expected a 3-D or 4-D "
            "array, got ndim=%d" % inputs.ndim)
    if inputs.ndim == 3:
        # Promote a single image to a batch of one.
        inputs = np.expand_dims(inputs, axis=0)
    # Read the target size once instead of on every loop iteration.
    height = self.net_params.input_img_height
    width = self.net_params.input_img_width
    # A private graph keeps these ops out of the caller's default graph.
    with tf.Graph().as_default() as cur_g:
        input_tensor = tf.convert_to_tensor(inputs, dtype=tf.uint8)
        # preprocess_fn works image-by-image, so split along the batch axis.
        processed = [preprocess_fn(image, height, width)
                     for image in tf.unstack(input_tensor)]
        new_inputs = tf.stack(processed)
    with tf.Session(graph=cur_g) as sess:
        return sess.run(new_inputs)
评论列表
文章目录