input_reader_cifar10.py source code

python

Project: easy-tensorflow  Author: khanhptnk
def read_input(self, data_path, batch_size, randomize_input=True,
               distort_inputs=True, name="read_input"):
    """Read input labeled images and make a batch of examples.

      Labeled images are read from files of tf.Example protos. This proto has
      to contain two features: `image` and `label`, corresponding to an image
      and its label. After being read, the labeled images are put into queues
      to make a batch of examples every time the batching op is executed.

      Args:
        data_path: a string, path to files of tf.Example protos containing
          labeled images.
        batch_size: an int, number of labeled images in a batch.
        randomize_input: a bool, whether the images in the batch are randomized.
        distort_inputs: a bool, whether to distort the images.
        name: a string, name of the op.
      Returns:
        keys: a TensorFlow op, the keys of tf.Example protos.
        examples: a TensorFlow op, a batch of examples containing labeled
          images. After being materialized, this op becomes a dict, in which the
          `decoded_observation` key is an image and the `decoded_label` is the
          label of that image.
    """
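    # Schema used to parse each tf.Example: a flattened image and one label.
    feature_types = {}
    # 3072 = 32 * 32 * 3 values per CIFAR-10 image.
    feature_types["image"] = tf.FixedLenFeature(
        shape=[3072,], dtype=tf.int64, default_value=None)

    feature_types["label"] = tf.FixedLenFeature(
        shape=[1,], dtype=tf.int64, default_value=None)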

    keys, examples = tf.contrib.learn.io.graph_io.read_keyed_batch_examples(
        file_pattern=data_path,
        batch_size=batch_size,
        reader=tf.TFRecordReader,
        randomize_input=randomize_input,
        queue_capacity=batch_size * 4,
        num_threads=10 if randomize_input else 1,
        parse_fn=lambda example_proto: self._preprocess_input(example_proto,
                                                              feature_types,
                                                              distort_inputs),
        name=name)

    return keys, examples
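
The reader arguments and feature shapes above can be checked without TensorFlow. The following is a minimal sketch in plain Python, not part of easy-tensorflow: the helper names `reader_settings` and `check_record` are hypothetical. It mirrors how `queue_capacity` and `num_threads` are derived from `batch_size` and `randomize_input`, and it checks that a record has the lengths declared in `feature_types`. The single reader thread in the non-randomized case is presumably there to keep the record order stable, since several threads would interleave records.

```python
# Hypothetical helpers for illustration; these names are not in the original file.

IMAGE_LEN = 32 * 32 * 3  # 3072 values, matching shape=[3072,] in feature_types
LABEL_LEN = 1            # matching shape=[1,] in feature_types


def reader_settings(batch_size, randomize_input=True):
    """Mirror the queue and thread arguments passed to the batch reader."""
    return {
        "queue_capacity": batch_size * 4,
        "num_threads": 10 if randomize_input else 1,
    }


def check_record(image, label):
    """Return True if a record has the lengths the feature spec expects."""
    return len(image) == IMAGE_LEN and len(label) == LABEL_LEN


if __name__ == "__main__":
    print(reader_settings(128))
    print(reader_settings(128, randomize_input=False))
    print(check_record([0] * IMAGE_LEN, [3]))
```

Running a check like `check_record` on data before it is serialized to tf.Example protos can catch a wrongly sized image early, since a `FixedLenFeature` with no default value fails at parse time when the stored length does not match the declared shape.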