input_data.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:tf-tutorial 作者: zchen0211 项目源码 文件源码
def __init__(self, train=True):
        """Build the per-sample input pipeline: placeholders, preprocessing, queue.

        A single raw image and its label are fed through ``self.dataX`` and
        ``self.dataY``. The stored mean is subtracted from the image, the result
        is cropped to ``FLAGS.crop_size`` and the (image, label) pair is pushed
        onto ``self.queue`` by running ``self.enqueue_op``.

        Args:
            train: When truthy, a ``tf.RandomShuffleQueue`` is used and the image
                goes through random crop, random horizontal flip and random
                brightness / saturation / contrast. Otherwise a ``tf.FIFOQueue``
                is used and the image is only cropped or padded to the target
                size.
        """
        self.train = train

        # Feed points for one raw image (raw_size x raw_size x 3) and its scalar label.
        self.dataX = tf.placeholder(dtype=tf.float32, shape=[FLAGS.raw_size, FLAGS.raw_size, 3])
        self.dataY = tf.placeholder(dtype=tf.int64, shape=[])

        # Read the 'data_mean' entry of the mean file and embed it as a graph constant.
        mean_path = os.path.join(FLAGS.data_dir, FLAGS.mean_file)
        mean_array = np.load(mean_path)['data_mean'].astype(np.float32)
        self.mean_dataX = tf.constant(mean_array, dtype=tf.float32)

        # Mean subtraction happens before any cropping or distortion.
        self.mean_sub_image = self.dataX - self.mean_dataX

        # Every queue element is a (cropped float32 image, int64 scalar label) pair.
        crop_shape = [FLAGS.crop_size, FLAGS.crop_size, 3]
        queue_spec = {
            'shapes': [crop_shape, []],
            'dtypes': [tf.float32, tf.int64],
        }

        if not self.train:
            # Evaluation: in-order queue, deterministic crop (or pad) to the target size.
            self.queue = tf.FIFOQueue(capacity=20000, **queue_spec)
            side = FLAGS.crop_size
            image = tf.image.resize_image_with_crop_or_pad(self.mean_sub_image, side, side)
        else:
            # Training: shuffling queue plus random augmentation.
            self.queue = tf.RandomShuffleQueue(capacity=2000, min_after_dequeue=1000, **queue_spec)
            image = tf.random_crop(self.mean_sub_image, crop_shape)
            image = tf.image.random_flip_left_right(image)
            # NOTE(review): the colour distortions below run on the mean-subtracted
            # image, which may hold negative values -- confirm this is intended.
            image = tf.image.random_brightness(image, max_delta=63. / 255.)
            image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
            image = tf.image.random_contrast(image, lower=0.2, upper=1.8)

        self.distorted_image = image

        # Running this op pushes one preprocessed (image, label) pair onto the queue.
        self.enqueue_op = self.queue.enqueue([self.distorted_image, self.dataY])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号