def __init__(self, train=True):
    """Build the TF1 graph that preprocesses one image and enqueues it.

    Creates placeholders for a single raw image and its label, subtracts the
    dataset mean loaded from ``os.path.join(FLAGS.data_dir, FLAGS.mean_file)``
    (key ``'data_mean'``), then either applies random training-time
    augmentation or a deterministic center crop, and builds an op that
    enqueues the resulting ``(image, label)`` pair.

    Queue elements are float32 images of shape
    ``[FLAGS.crop_size, FLAGS.crop_size, 3]`` paired with scalar int64 labels.

    Args:
        train: If True, use a ``tf.RandomShuffleQueue`` (capacity 2000,
            min_after_dequeue 1000) with random crop, left/right flip, and
            brightness/saturation/contrast jitter. If False, use a
            ``tf.FIFOQueue`` (capacity 20000) with a center crop/pad.

    Attributes set:
        train, dataX, dataY, mean_dataX, mean_sub_image, queue,
        distorted_image, enqueue_op.
    """
    self.train = train  # training mode or not

    # Placeholders fed with one raw image and its scalar label per enqueue.
    self.dataX = tf.placeholder(dtype=tf.float32,
                                shape=[FLAGS.raw_size, FLAGS.raw_size, 3])
    self.dataY = tf.placeholder(dtype=tf.int64, shape=[])

    # Load the dataset mean. For an .npz archive np.load returns an NpzFile
    # that holds the file open until closed, so close it once the array has
    # been read into memory. A plain ndarray has no close(); skip it then.
    mean_archive = np.load(os.path.join(FLAGS.data_dir, FLAGS.mean_file))
    try:
        mean_ = mean_archive['data_mean'].astype(np.float32)
    finally:
        close = getattr(mean_archive, 'close', None)
        if close is not None:
            close()
    self.mean_dataX = tf.constant(mean_, dtype=tf.float32)

    # Mean subtraction; mean_ must broadcast against [raw_size, raw_size, 3].
    self.mean_sub_image = self.dataX - self.mean_dataX

    # Each queue element is a cropped image plus a scalar label.
    element_shapes = [[FLAGS.crop_size, FLAGS.crop_size, 3], []]
    element_dtypes = [tf.float32, tf.int64]

    if self.train:
        self.queue = tf.RandomShuffleQueue(shapes=element_shapes,
                                           dtypes=element_dtypes,
                                           capacity=2000,
                                           min_after_dequeue=1000)
        # Random crop down to crop_size x crop_size.
        self.distorted_image = tf.random_crop(
            self.mean_sub_image, [FLAGS.crop_size, FLAGS.crop_size, 3])
        # Random horizontal flip.
        self.distorted_image = tf.image.random_flip_left_right(
            self.distorted_image)
        # Random brightness, saturation and contrast.
        # NOTE(review): max_delta=63/255 assumes a [0, 1] pixel scale; if
        # dataX is fed in [0, 255] this jitter is negligible -- confirm.
        self.distorted_image = tf.image.random_brightness(
            self.distorted_image, max_delta=63. / 255.)
        # NOTE(review): random_saturation round-trips through RGB->HSV, which
        # TF documents as well-defined only for inputs in [0, 1]; here it runs
        # on the mean-subtracted (possibly negative) image. Confirm intended.
        self.distorted_image = tf.image.random_saturation(
            self.distorted_image, lower=0.5, upper=1.5)
        self.distorted_image = tf.image.random_contrast(
            self.distorted_image, lower=0.2, upper=1.8)
    else:
        self.queue = tf.FIFOQueue(shapes=element_shapes,
                                  dtypes=element_dtypes,
                                  capacity=20000)
        # Deterministic center crop (zero-pads instead if the image is smaller
        # than crop_size).
        self.distorted_image = tf.image.resize_image_with_crop_or_pad(
            self.mean_sub_image, FLAGS.crop_size, FLAGS.crop_size)

    # Op that pushes one preprocessed (image, label) pair onto the queue.
    self.enqueue_op = self.queue.enqueue([self.distorted_image, self.dataY])
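# [scraped web-page residue, not code] 评论列表 ("Comment list")
# [scraped web-page residue, not code] 文章目录 ("Table of contents")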