def get_input_fn(filename):
    """Returns an `input_fn` for train and eval.

    Args:
        filename: Passed unchanged to `tf.data.TFRecordDataset`. Each record
            is parsed as a serialized `tf.Example` holding a string
            "image_raw" feature and an int64 "label" feature.

    Returns:
        A callable `input_fn(params)` that builds the input pipeline and
        returns an `(images, labels)` pair of tensors.
    """

    def input_fn(params):
        """A simple input_fn using the experimental input pipeline.

        Args:
            params: Mapping that must contain the key "batch_size".

        Returns:
            `(images, labels)` obtained from a one-shot iterator over the
            batched dataset: float32 images and int32 labels.
        """
        # Retrieves the batch size for the current shard. The # of shards is
        # computed according to the input pipeline deployment. See
        # `tf.contrib.tpu.RunConfig` for details.
        batch_size = params["batch_size"]

        def parser(serialized_example):
            """Parses a single tf.Example into image and label tensors."""
            features = tf.parse_single_example(
                serialized_example,
                features={
                    "image_raw": tf.FixedLenFeature([], tf.string),
                    "label": tf.FixedLenFeature([], tf.int64),
                })
            # The raw bytes are reinterpreted as uint8 pixels; the static
            # shape is pinned to a flat 28 * 28 vector.
            image = tf.decode_raw(features["image_raw"], tf.uint8)
            image.set_shape([28 * 28])
            # Normalize the values of the image from the range [0, 255] to
            # [-0.5, 0.5]
            image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
            label = tf.cast(features["label"], tf.int32)
            return image, label

        # The reader buffer size comes from the `dataset_reader_buffer_size`
        # flag, which is defined outside this function.
        dataset = tf.data.TFRecordDataset(
            filename, buffer_size=FLAGS.dataset_reader_buffer_size)
        # Parse first, then cache, so parsed examples are what gets cached;
        # repeat() with no count makes the dataset loop indefinitely.
        dataset = dataset.map(parser).cache().repeat()
        # NOTE(review): `tf.contrib.data.batch_and_drop_remainder` was
        # deprecated in later TF 1.x releases in favor of
        # `dataset.batch(batch_size, drop_remainder=True)` -- confirm the
        # target TF version before migrating.
        dataset = dataset.apply(
            tf.contrib.data.batch_and_drop_remainder(batch_size))
        images, labels = dataset.make_one_shot_iterator().get_next()
        return images, labels

    return input_fn
# Comment list
# Article table of contents