input_reader_cifar10.py source code

python

Project: easy-tensorflow, Author: khanhptnk
import tensorflow as tf  # TensorFlow 1.x API


# Method of the InputReaderCifar10 class.
def _preprocess_input(self, example_proto, feature_types, distort_inputs):
    """Parse a tf.Example proto and preprocess its image and label.

    Args:
        example_proto: a scalar string tensor holding a serialized
            tf.Example proto.
        feature_types: a dict used to parse the tf.Example proto. This is the
            same `feature_types` dict constructed in the `read_input` method.
        distort_inputs: a bool, whether to randomly distort the images (for
            training) instead of center-cropping them (for evaluation).

    Returns:
        example: a dict of tensors. The `decoded_observation` key holds the
            processed image, a float32 tensor of shape
            [IMAGE_CROPPED_SIZE, IMAGE_CROPPED_SIZE, NUM_CHANNELS], and the
            `decoded_label` key holds the label of that image, a one-hot
            int64 vector of size NUM_CLASSES (all constants are defined on
            InputReaderCifar10).
    """
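    example = tf.parse_single_example(example_proto, feature_types)
    # The flat pixel vector is stored channel-first (all red values, then
    # green, then blue), so reshape to [C, H, W] and transpose to [H, W, C].
    image = tf.reshape(example["image"], [InputReaderCifar10.NUM_CHANNELS,
                                          InputReaderCifar10.IMAGE_SIZE,
                                          InputReaderCifar10.IMAGE_SIZE])
    image = tf.transpose(image, perm=[1, 2, 0])
    image = tf.cast(image, tf.float32)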
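    if distort_inputs:
        # Training: random crop, random horizontal flip, and random
        # brightness/contrast jitter for data augmentation.
        image = tf.random_crop(image, [InputReaderCifar10.IMAGE_CROPPED_SIZE,
                                       InputReaderCifar10.IMAGE_CROPPED_SIZE,
                                       InputReaderCifar10.NUM_CHANNELS])
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, max_delta=63)
        image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
    else:
        # Evaluation: deterministic central crop to the same size.
        image = tf.image.resize_image_with_crop_or_pad(
            image,
            InputReaderCifar10.IMAGE_CROPPED_SIZE,
            InputReaderCifar10.IMAGE_CROPPED_SIZE)
    # Scale each image to zero mean and unit variance. per_image_whitening is
    # the older name of this op; TensorFlow 1.x provides it under this name.
    image = tf.image.per_image_standardization(image)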
    example["decoded_observation"] = image

    # One-hot encode the integer class id and store it as an int64 vector.
    label = tf.one_hot(example["label"], InputReaderCifar10.NUM_CLASSES,
                       on_value=1, off_value=0)
    label = tf.reshape(label, [InputReaderCifar10.NUM_CLASSES])
    label = tf.cast(label, tf.int64)
    example["decoded_label"] = label

    return example
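The decoding steps of `_preprocess_input` reshape the flat pixel vector to `[NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE]` and then transpose with `perm=[1, 2, 0]`, which implies the pixels are stored channel-first, as in the CIFAR-10 distribution files. The label becomes a one-hot int64 vector. The NumPy sketch below reproduces both steps outside TensorFlow. It assumes the class constants take CIFAR-10's dimensions (32x32 pixels, 3 channels, 10 classes), since their values are not shown in this file; `decode_image` and `decode_label` are hypothetical helper names.

```python
import numpy as np

# Assumed to match InputReaderCifar10's constants (CIFAR-10 dimensions).
IMAGE_SIZE = 32
NUM_CHANNELS = 3
NUM_CLASSES = 10


def decode_image(flat_pixels):
    """Turn a flat channel-first pixel vector into an HWC float32 image."""
    chw = np.reshape(flat_pixels, [NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE])
    hwc = np.transpose(chw, (1, 2, 0))  # same permutation as perm=[1, 2, 0]
    return hwc.astype(np.float32)


def decode_label(label):
    """One-hot encode an integer class id as an int64 vector."""
    return np.eye(NUM_CLASSES, dtype=np.int64)[label]


# Pixel values encode their own position: R plane 0..1023, G 1024.., B 2048..
flat = np.arange(NUM_CHANNELS * IMAGE_SIZE * IMAGE_SIZE)
image = decode_image(flat)
print(image.shape)      # (32, 32, 3)
print(decode_label(3))  # [0 0 0 1 0 0 0 0 0 0]
```

Because the test vector numbers pixels in storage order, `image[0, 0]` holds the first red, green, and blue values (0, 1024, 2048), which confirms the transpose gathers one value from each channel plane per pixel.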
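The image branch of `_preprocess_input` either applies random distortions (training) or a deterministic central crop (evaluation), then standardizes each image. The NumPy sketch below mirrors that control flow under stated assumptions. `IMAGE_CROPPED_SIZE = 24` is an assumed value, since the constant is not defined in this file. The distortions follow the documented behavior of the TensorFlow ops: `random_brightness` adds a delta drawn from `[-max_delta, max_delta]`, and `random_contrast` scales each channel's deviation from its mean by a factor drawn from `[lower, upper]`. Standardization follows `tf.image.per_image_standardization`, `(x - mean) / max(stddev, 1 / sqrt(num_elements))`. The helper names are hypothetical, and the random draws will not match TensorFlow's bit for bit.

```python
import numpy as np

IMAGE_CROPPED_SIZE = 24  # assumption: the constant's value is not shown here


def random_crop(image, size, rng):
    """Crop a size x size window at a uniformly random offset."""
    h, w, _ = image.shape
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    return image[top:top + size, left:left + size, :]


def center_crop(image, size):
    """Crop the central size x size window (the eval branch)."""
    h, w, _ = image.shape
    top, left = (h - size) // 2, (w - size) // 2
    return image[top:top + size, left:left + size, :]


def per_image_standardization(image):
    """Zero mean, unit variance over the whole image, with a floored stddev."""
    adjusted_std = max(image.std(), 1.0 / np.sqrt(image.size))
    return (image - image.mean()) / adjusted_std


def preprocess(image, distort, rng):
    """Mirror the image branch of _preprocess_input on an HWC float32 array."""
    if distort:
        image = random_crop(image, IMAGE_CROPPED_SIZE, rng)
        if rng.random() < 0.5:                    # random left-right flip
            image = image[:, ::-1, :]
        image = image + rng.uniform(-63, 63)      # random brightness
        factor = rng.uniform(0.2, 1.8)            # random contrast
        mean = image.mean(axis=(0, 1), keepdims=True)  # per-channel mean
        image = (image - mean) * factor + mean
    else:
        image = center_crop(image, IMAGE_CROPPED_SIZE)
    return per_image_standardization(image).astype(np.float32)


rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(32, 32, 3)).astype(np.float32)
train_image = preprocess(raw, distort=True, rng=rng)
eval_image = preprocess(raw, distort=False, rng=rng)
print(train_image.shape, eval_image.shape)  # (24, 24, 3) (24, 24, 3)
```

Both branches end at the same shape, so training and evaluation batches can feed the same network. The stddev floor in the standardization step avoids dividing by zero on a uniform (constant) image.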