mnist.py source code

python

Project: tensorflow-input-pipelines  Author: ischlag
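def __build_generic_data_tensor(self, raw_images, raw_targets, shuffle, augmentation):
    """ Creates the input pipeline and performs some preprocessing. """
    # Requires TensorFlow 1.x (import tensorflow as tf): the queue-based
    # tf.train.* input functions used below were removed in TensorFlow 2.x.

    # Wrap the in-memory NumPy arrays as constant tensors.
    images = tf.convert_to_tensor(raw_images)
    targets = tf.convert_to_tensor(raw_targets)

    set_size = raw_images.shape[0]

    # MNIST images are 28x28 grayscale: reshape to [N, height, width, channels].
    images = tf.reshape(images, [set_size, 28, 28, 1])
    # Produce one (image, label) pair at a time; with shuffle=True the
    # example order is reshuffled every epoch.
    image, label = tf.train.slice_input_producer([images, targets], shuffle=shuffle)

    # Data augmentation
    if augmentation:
        # Zero-pad to (H+4)x(W+4), then take a random HxW crop: a random
        # translation of up to 2 pixels in each direction.
        image = tf.image.resize_image_with_crop_or_pad(image, self.IMAGE_HEIGHT + 4, self.IMAGE_WIDTH + 4)
        image = tf.random_crop(image, [self.IMAGE_HEIGHT, self.IMAGE_WIDTH, self.NUM_OF_CHANNELS])
        # Mirror the image horizontally with probability 0.5.
        image = tf.image.random_flip_left_right(image)

    # Rescale each image to zero mean and unit variance.
    image = tf.image.per_image_standardization(image)

    # Group single examples into batches. The enqueue threads only run after
    # tf.train.start_queue_runners() has been called for the session.
    images_batch, labels_batch = tf.train.batch([image, label], batch_size=self.batch_size, num_threads=self.NUM_THREADS)

    return images_batch, labels_batch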
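To make each preprocessing step concrete without a TensorFlow 1.x session, here is a minimal NumPy sketch of the same pipeline. It is an illustration, not the author's code. It assumes 28x28x1 MNIST images stored as a NumPy array of shape [N, 784] or [N, 28, 28] with a NumPy label array. The helper names `augment`, `per_image_standardization` and `batches` are hypothetical. The padding mirrors `resize_image_with_crop_or_pad` (2 pixels on each side), and the standardization follows the documented formula of `tf.image.per_image_standardization`: `(x - mean) / max(stddev, 1/sqrt(num_elements))`. Unlike the queue-based version, which cycles through the data indefinitely, this generator makes a single pass and drops the final partial batch.

```python
import numpy as np

IMAGE_HEIGHT, IMAGE_WIDTH, NUM_OF_CHANNELS = 28, 28, 1


def per_image_standardization(image):
    """Zero mean / unit variance per image; the stddev is floored at 1/sqrt(N)
    so that constant images do not cause a division by zero."""
    image = image.astype(np.float32)
    adjusted_std = max(float(image.std()), 1.0 / np.sqrt(image.size))
    return (image - image.mean()) / adjusted_std


def augment(image, rng):
    """Pad by 2 pixels per side, take a random 28x28 crop, randomly flip left-right."""
    padded = np.pad(image, ((2, 2), (2, 2), (0, 0)), mode="constant")
    max_top = padded.shape[0] - IMAGE_HEIGHT   # 4 for MNIST
    max_left = padded.shape[1] - IMAGE_WIDTH   # 4 for MNIST
    top = rng.integers(0, max_top + 1)
    left = rng.integers(0, max_left + 1)
    cropped = padded[top:top + IMAGE_HEIGHT, left:left + IMAGE_WIDTH, :]
    if rng.random() < 0.5:
        cropped = cropped[:, ::-1, :]  # mirror along the width axis
    return cropped


def batches(raw_images, raw_targets, batch_size, shuffle=True, augmentation=False, seed=0):
    """Yield (images_batch, labels_batch) pairs for one epoch (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    images = raw_images.reshape(-1, IMAGE_HEIGHT, IMAGE_WIDTH, NUM_OF_CHANNELS)
    order = rng.permutation(len(images)) if shuffle else np.arange(len(images))
    for start in range(0, len(order) - batch_size + 1, batch_size):
        idx = order[start:start + batch_size]
        processed = []
        for i in idx:
            img = augment(images[i], rng) if augmentation else images[i]
            processed.append(per_image_standardization(img))
        yield np.stack(processed), raw_targets[idx]
```

For example, `next(batches(train_images, train_labels, batch_size=128, augmentation=True))` returns an array of shape `(128, 28, 28, 1)` and the matching 128 labels. Seeding a single `rng` keeps shuffling and augmentation reproducible across runs.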