mnist.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:tpu-demos 作者: tensorflow 项目源码 文件源码
def get_input_fn(filename):
  """Builds the `input_fn` shared by training and evaluation.

  Args:
    filename: Passed unchanged to `tf.data.TFRecordDataset` as the source of
      serialized `tf.Example` records.

  Returns:
    A callable `input_fn(params)` producing one `(images, labels)` batch.
  """

  def input_fn(params):
    """Reads, parses and batches the records with the `tf.data` pipeline.

    Args:
      params: Dict that must contain `"batch_size"`.

    Returns:
      An `(images, labels)` tuple of tensors taken from a one-shot iterator.
    """
    # The batch size here is the one for the current shard; how many shards
    # exist follows from the input pipeline deployment (see
    # `tf.contrib.tpu.RunConfig`).
    per_shard_batch = params["batch_size"]

    def _parse_record(record):
      """Turns one serialized tf.Example into an (image, label) pair."""
      feature_spec = {
          "image_raw": tf.FixedLenFeature([], tf.string),
          "label": tf.FixedLenFeature([], tf.int64),
      }
      parsed = tf.parse_single_example(record, features=feature_spec)

      # Raw bytes -> flat uint8 vector with a statically known length.
      pixels = tf.decode_raw(parsed["image_raw"], tf.uint8)
      pixels.set_shape([28 * 28])

      # Rescale [0, 255] to [0, 1], then shift to [-0.5, 0.5].
      scaled = tf.cast(pixels, tf.float32) * (1. / 255)
      centered = scaled - 0.5

      return centered, tf.cast(parsed["label"], tf.int32)

    records = tf.data.TFRecordDataset(
        filename, buffer_size=FLAGS.dataset_reader_buffer_size)

    # Parse first, then cache, so repeated passes reuse the parsed records.
    examples = records.map(_parse_record)
    examples = examples.cache()
    examples = examples.repeat()

    # Incomplete trailing batches are dropped so every batch has the same size.
    batches = examples.apply(
        tf.contrib.data.batch_and_drop_remainder(per_shard_batch))

    iterator = batches.make_one_shot_iterator()
    images, labels = iterator.get_next()
    return images, labels

  return input_fn
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号