def train_inputs(data_dir):
    """Build shuffled, standardized image/label batches for CIFAR training.

    The batch size is wired in through a ``tf.placeholder_with_default``
    whose default is ``FLAGS.batch_size``, so a different value can be
    supplied through the feed dict when ``sess.run`` is called (for example
    at inference time).

    NOTE(review): the placeholder is neither returned nor given an explicit
    name here, so a caller that wants to feed it has to look it up in the
    graph.

    Args:
      data_dir: Path to the CIFAR-10 data directory.

    Returns:
      images: Images. 4D tensor [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3].
      labels: Labels. 1D tensor [batch_size].
    """
    # Fetch one (image, label) example. The leading False is handed straight
    # to get_raw_input_data -- presumably it selects the training split;
    # confirm against that helper.
    example, example_label = get_raw_input_data(False, data_dir)

    # Augmentation is optional and switched on by the data_aug flag.
    augment = tf.app.flags.FLAGS.data_aug
    sample = distort_image(example) if augment else example

    # Per-image standardization: subtract the image mean and divide by its
    # (adjusted) standard deviation.
    standardized = tf.image.per_image_standardization(sample)

    # Scalar batch size that defaults to the flag value but stays feedable.
    batch_size_op = tf.placeholder_with_default(FLAGS.batch_size, shape=[])

    # Shuffling queue: keep at least `min_after_dequeue` examples buffered so
    # the shuffle mixes well, with headroom for three default-sized batches.
    min_after_dequeue = 20000
    queue_capacity = min_after_dequeue + 3 * FLAGS.batch_size
    images, labels = tf.train.shuffle_batch(
        [standardized, example_label],
        batch_size=batch_size_op,
        num_threads=NUM_THREADS,
        capacity=queue_capacity,
        min_after_dequeue=min_after_dequeue)

    # Expose the training images to the summary visualizer.
    tf.summary.image('images', images)

    # Flatten the labels to a 1D tensor before handing them back.
    flat_labels = tf.reshape(labels, [-1])
    return images, flat_labels
# Comment list
# Table of contents