def get_preprocessing(name, is_training=False):
    """Returns preprocessing_fn(image, height, width, **kwargs).

    Args:
        name: The name of the preprocessing function. Must be one of the keys
            of the internal preprocessing map (currently only 'vgg_ucf').
        is_training: `True` if the model is being used for training and
            `False` otherwise. Forwarded to the underlying `preprocess_image`.

    Returns:
        preprocessing_fn: A function that preprocesses either a single image
        (rank-3 tensor) or a set of images stacked along the first axis
        (rank-4 tensor). It has the following signature:
            image = preprocessing_fn(image, output_height, output_width, ...).

    Raises:
        ValueError: If preprocessing `name` is not recognized. The returned
            `preprocessing_fn` also raises ValueError when `image` is not a
            rank-3 or rank-4 tensor (including unknown rank), or when a
            rank-4 `image` has an unknown static first dimension.
    """
    preprocessing_fn_map = {
        'vgg_ucf': vgg_ucf_preprocessing,
    }
    if name not in preprocessing_fn_map:
        raise ValueError('Preprocessing name [%s] was not recognized' % name)
    # `name` has been validated, so the lookup can be done once here.
    preprocessor = preprocessing_fn_map[name]

    def preprocessing_fn(image, output_height, output_width, **kwargs):
        with tf.variable_scope('preprocess_image'):
            # `ndims` is None for an unknown rank (where len() would raise an
            # opaque error), so that case falls through to the explicit error.
            ndims = image.get_shape().ndims
            if ndims == 3:
                return preprocessor.preprocess_image(
                    image, output_height, output_width,
                    is_training=is_training, **kwargs)
            if ndims == 4:
                # Preprocess all the images in one set in the same way by
                # concatenating them along the channel axis, running the
                # preprocessing once, then splitting them apart again.
                n_imgs = image.get_shape().as_list()[0]
                if n_imgs is None:
                    # tf.unpack / tf.split need a static count of images.
                    raise ValueError(
                        'Rank-4 image must have a statically known first '
                        'dimension (number of images); got shape %s'
                        % image.get_shape())
                final_img_concat = preprocessor.preprocess_image(
                    tf.concat(2, tf.unpack(image)),
                    output_height, output_width,
                    is_training=is_training, **kwargs)
                # NOTE(review): splitting on axis 3 assumes preprocess_image
                # returns a rank-4 tensor (presumably [1, H, W, n_imgs * C]);
                # confirm against vgg_ucf_preprocessing.
                return tf.concat(0, tf.split(3, n_imgs, final_img_concat))
            # Previously this only printed a message and returned None,
            # deferring the failure to an unrelated place downstream.
            raise ValueError(
                'Incorrect dims image! Expected a rank 3 or 4 tensor, '
                'got rank %s' % ndims)

    return preprocessing_fn
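# 评论列表 ("Comment list") -- stray web-page text, not part of the code.
# 文章目录 ("Table of contents") -- stray web-page text, not part of the code.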