def process_data(sess, filenames):
    """Build a shuffled batch pipeline of (low-res input, original) image pairs.

    Each JPEG listed in ``filenames`` is decoded and scaled to [0, 1].  It is
    then downsampled by ``FLAGS.scale`` with area interpolation and upsampled
    back to ``FLAGS.input_image_size``.  The re-upsampled image is the
    network input (feature); the original image is the target (label).

    Args:
        sess: session on which the queue runners are started.
        filenames: JPEG file paths fed to the filename queue.

    Returns:
        (features, labels): float32 tensors batched by
        ``tf.train.shuffle_batch``, each of shape
        [FLAGS.batch_size, FLAGS.input_image_size, FLAGS.input_image_size,
        FLAGS.image_channels].

    Note:
        Decoded images are assumed to already be
        input_image_size x input_image_size.  The reshape below does not
        resize them, so images of another size fail (or are scrambled if the
        element count happens to match).
    """
    images_size = FLAGS.input_image_size
    channels = FLAGS.image_channels
    K = FLAGS.scale

    reader = tf.WholeFileReader()
    filename_queue = tf.train.string_input_producer(filenames)
    _, value = reader.read(filename_queue)

    image = tf.image.decode_jpeg(
        value, channels=channels, name="dataset_image")
    # add data augmentation here
    image.set_shape([None, None, channels])
    # Add a leading batch dimension because resize_area takes 4-D input.
    # Use `channels` (not a hard-coded 3) so non-RGB configs stay consistent.
    image = tf.reshape(image, [1, images_size, images_size, channels])
    image = tf.cast(image, tf.float32) / 255.0

    downsampled = tf.image.resize_area(
        image, [images_size // K, images_size // K])
    upsampled = tf.image.resize_area(downsampled, [images_size, images_size])

    # Drop the batch dimension again; shuffle_batch adds its own.
    feature = tf.reshape(upsampled, [images_size, images_size, channels])
    label = tf.reshape(image, [images_size, images_size, channels])

    features, labels = tf.train.shuffle_batch(
        [feature, label],
        batch_size=FLAGS.batch_size,
        num_threads=4,
        capacity=5000,
        min_after_dequeue=1000,
        name='labels_and_features')

    # Queue runners must be running before any batch can be dequeued.
    # The former debug print (`features.eval()`) was removed.  It consumed
    # and discarded a whole training batch, ran on the default session
    # rather than `sess`, and used Python-2-only syntax.
    tf.train.start_queue_runners(sess=sess)
    return features, labels
评论列表
文章目录