def read_input(self, data_path, batch_size, randomize_input=True,
               distort_inputs=True, name="read_input"):
    """Build the op that reads labeled images and groups them into batches.

    The files matched by `data_path` are TFRecord files of serialized
    tf.Example protos. Every proto must carry two int64 features:
    `image` (a flat vector of 3072 values) and `label` (a single value).
    Each record is passed through `self._preprocess_input` and the results
    are queued, so that a new batch is produced every time the returned
    op is executed.

    Args:
      data_path: a string, file pattern of the tf.Example record files.
      batch_size: an int, number of labeled images per batch.
      randomize_input: a bool, whether the input is randomized. It also
        selects the number of reader threads (10 if True, otherwise 1).
      distort_inputs: a bool, forwarded to `self._preprocess_input` to
        request image distortion.
      name: a string, name of the op.

    Returns:
      keys: a tensorflow op, the keys of the tf.Example protos.
      examples: a tensorflow op, a batch of preprocessed examples. Its
        structure is whatever `self._preprocess_input` returns; it was
        originally documented as a dict whose `decoded_observation` entry
        is the image and whose `decoded_label` entry is its label.
    """
    # Parsing schema for one serialized tf.Example.
    feature_spec = {
        "image": tf.FixedLenFeature(
            shape=[3072,], dtype=tf.int64, default_value=None),
        "label": tf.FixedLenFeature(
            shape=[1,], dtype=tf.int64, default_value=None),
    }

    def _parse_example(example_proto):
        # Bind the schema and the distortion flag for the batching reader,
        # which calls the parse function with the serialized proto only.
        return self._preprocess_input(
            example_proto, feature_spec, distort_inputs)

    # A single reader thread is used when the input is not randomized --
    # presumably to keep the record order deterministic; confirm.
    if randomize_input:
        reader_threads = 10
    else:
        reader_threads = 1

    example_keys, example_batch = (
        tf.contrib.learn.io.graph_io.read_keyed_batch_examples(
            file_pattern=data_path,
            batch_size=batch_size,
            reader=tf.TFRecordReader,
            randomize_input=randomize_input,
            queue_capacity=batch_size * 4,
            num_threads=reader_threads,
            parse_fn=_parse_example,
            name=name))
    return example_keys, example_batch
# (Web-page residue, not part of the code: "Comment list" / "Table of contents".)