def _preprocess_input(self, example_proto, feature_types, distort_inputs):
    """Parse a tf.Example proto and preprocess its image and label.

    Args:
      example_proto: a string tensor holding one serialized tf.Example proto.
      feature_types: a dict, used for parsing a tf.Example proto. This is the
        same `feature_types` dict constructed in the `read_input` method.
      distort_inputs: a bool, whether to apply random distortions (random
        crop, horizontal flip, brightness and contrast jitter) to the image.
        When False, the image is center-cropped/padded instead.

    Returns:
      example: a dict of tensors. Besides the raw features parsed with
        `feature_types`, it holds:
          `decoded_observation`: a float32, per-image standardized image
            tensor of size InputReaderCifar10.IMAGE_CROPPED_SIZE x
            InputReaderCifar10.IMAGE_CROPPED_SIZE x
            InputReaderCifar10.NUM_CHANNELS.
          `decoded_label`: an int64 one-hot vector of size
            InputReaderCifar10.NUM_CLASSES.
    """
    example = tf.parse_single_example(example_proto, feature_types)

    # The raw image is treated as channel-major (C, H, W); transpose it to the
    # (H, W, C) layout that the tf.image ops below operate on.
    image = tf.reshape(example["image"], [InputReaderCifar10.NUM_CHANNELS,
                                          InputReaderCifar10.IMAGE_SIZE,
                                          InputReaderCifar10.IMAGE_SIZE])
    image = tf.transpose(image, perm=[1, 2, 0])
    image = tf.cast(image, tf.float32)

    if distort_inputs:
        # Random crop keeps every channel; use the class constant rather than
        # a hard-coded 3 so the crop stays consistent with the reshape above.
        image = tf.random_crop(image, [InputReaderCifar10.IMAGE_CROPPED_SIZE,
                                       InputReaderCifar10.IMAGE_CROPPED_SIZE,
                                       InputReaderCifar10.NUM_CHANNELS])
        image = tf.image.random_flip_left_right(image)
        # NOTE(review): max_delta=63 presumably assumes raw pixel values in
        # [0, 255] (not yet normalized) — confirm against the stored data.
        image = tf.image.random_brightness(image, max_delta=63)
        image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
    else:
        # Deterministic evaluation path: center crop (or pad) to the same
        # spatial size the training path produces.
        image = tf.image.resize_image_with_crop_or_pad(
            image,
            InputReaderCifar10.IMAGE_CROPPED_SIZE,
            InputReaderCifar10.IMAGE_CROPPED_SIZE)

    # `per_image_whitening` was renamed `per_image_standardization` (same
    # computation) and the old name was removed in TensorFlow 1.0; prefer the
    # new name when it exists so this works on old and new releases alike.
    standardize = getattr(tf.image, "per_image_standardization", None)
    if standardize is None:
        standardize = tf.image.per_image_whitening
    image = standardize(image)
    example["decoded_observation"] = image

    # NOTE(review): presumably example["label"] is a single integer class id
    # (e.g. shape [1]); the reshape below requires the one-hot result to have
    # exactly NUM_CLASSES elements — confirm against `read_input`.
    label = tf.one_hot(example["label"], InputReaderCifar10.NUM_CLASSES,
                       on_value=1, off_value=0)
    label = tf.reshape(label, [InputReaderCifar10.NUM_CLASSES])
    label = tf.cast(label, tf.int64)
    example["decoded_label"] = label
    return example
评论列表
文章目录