def data_augmentation_fn(input_image: tf.Tensor, label_image: tf.Tensor, *,
                         contrast_range=(0.8, 1.0),
                         hue_max_delta=0.1,
                         saturation_range=(0.8, 1.2)) -> "tuple[tf.Tensor, tf.Tensor]":
    """Randomly augment an (input, label) image pair.

    Geometric augmentations (a left-right flip and an up-down flip, each with
    probability 0.5) are applied identically to both images so they stay
    spatially aligned. Photometric augmentations (random contrast, plus random
    hue and saturation for 3-channel inputs) are applied to ``input_image``
    only, so label values are never altered.

    Args:
        input_image: Image tensor to augment. NOTE(review): presumably
            (height, width, channels) as expected by the ``tf.image`` ops
            used here — confirm against the caller.
        label_image: Label/target image; only flipped, together with
            ``input_image``.
        contrast_range: ``(lower, upper)`` passed to
            ``tf.image.random_contrast``. Default matches the original
            hard-coded ``(0.8, 1.0)``.
        hue_max_delta: ``max_delta`` passed to ``tf.image.random_hue``
            (3-channel inputs only). Default ``0.1``.
        saturation_range: ``(lower, upper)`` passed to
            ``tf.image.random_saturation`` (3-channel inputs only).
            Default ``(0.8, 1.2)``.

    Returns:
        The augmented ``(input_image, label_image)`` pair.
    """
    with tf.name_scope('DataAugmentation'):
        with tf.name_scope('random_flip_lr'):
            # A single random draw drives both conds so input and label
            # are flipped (or not) together.
            do_flip = tf.random_uniform([], 0, 1) > 0.5
            label_image = tf.cond(do_flip,
                                  lambda: tf.image.flip_left_right(label_image),
                                  lambda: label_image)
            input_image = tf.cond(do_flip,
                                  lambda: tf.image.flip_left_right(input_image),
                                  lambda: input_image)
        with tf.name_scope('random_flip_ud'):
            do_flip = tf.random_uniform([], 0, 1) > 0.5
            label_image = tf.cond(do_flip,
                                  lambda: tf.image.flip_up_down(label_image),
                                  lambda: label_image)
            input_image = tf.cond(do_flip,
                                  lambda: tf.image.flip_up_down(input_image),
                                  lambda: input_image)
        # Static channel count; read before the photometric ops.
        channels = input_image.get_shape()[-1]
        contrast_lower, contrast_upper = contrast_range
        input_image = tf.image.random_contrast(input_image, lower=contrast_lower, upper=contrast_upper)
        # Hue and saturation are only defined for 3-channel (RGB) images.
        # If the channel dimension is not statically known, this comparison is
        # falsy and both ops are skipped.
        if channels == 3:
            saturation_lower, saturation_upper = saturation_range
            input_image = tf.image.random_hue(input_image, max_delta=hue_max_delta)
            input_image = tf.image.random_saturation(input_image, lower=saturation_lower,
                                                     upper=saturation_upper)
    return input_image, label_image
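# NOTE(review): the two lines below are web-page residue from where this code
# was copied ("评论列表" = "Comment list", "文章目录" = "Table of contents");
# commented out so the module parses.
# 评论列表
# 文章目录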