def get_raw_input_data(test_data, data_dir):
    """Build raw CIFAR-10 input ops using TensorFlow's queue-based Reader ops.

    Each fixed-length record in the binary batch files is read as 1 label
    byte followed by IMAGE_DEPTH * IMAGE_HEIGHT * IMAGE_WIDTH image bytes,
    with the image stored depth-major (D x H x W).

    NOTE: the filename queue is fed by queue runners; callers presumably
    start them (e.g. tf.train.start_queue_runners) before evaluating the
    returned ops -- confirm against the calling code.

    Args:
        test_data: bool, if True read 'test_batch.bin', otherwise read the
            training files 'data_batch_1.bin' .. 'data_batch_5.bin'.
        data_dir: Path to the CIFAR-10 data directory; the batch files are
            expected under its 'cifar-10-batches-bin' subdirectory.

    Returns:
        A tuple (image, label):
        image: float32 op of shape (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_DEPTH)
            (presumably 32x32x3 for standard CIFAR-10) holding the raw,
            unnormalized uint8 pixel values cast to float.
        label: int32 op of shape [1] holding the record's label byte.

    Raises:
        ValueError: if data_dir or any expected batch file does not exist.
    """
    # Verify first that we have a valid data directory. Use tf.gfile, like the
    # per-file check below, so non-local filesystems are handled consistently.
    if not tf.gfile.Exists(data_dir):
        raise ValueError("Data directory %s doesn't exist" % data_dir)

    # Construct the list of input file names.
    batches_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
    if test_data:
        filenames = [os.path.join(batches_dir, 'test_batch.bin')]
    else:
        # `range` instead of the Python-2-only `xrange`; same values (1..5).
        filenames = [os.path.join(batches_dir, 'data_batch_%d.bin' % ii)
                     for ii in range(1, 6)]

    # Make sure all input files actually exist before building the pipeline.
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    # Create a string input producer to cycle over the file names.
    filenames_queue = tf.train.string_input_producer(filenames)

    # CIFAR data samples are stored as a contiguous label followed by an image.
    label_size = 1
    image_size = IMAGE_DEPTH * IMAGE_HEIGHT * IMAGE_WIDTH

    # Fixed-length reader: every record is exactly label + image bytes.
    reader = tf.FixedLengthRecordReader(label_size + image_size)
    # The record key is not needed; only the raw record value is used.
    _, value = reader.read(filenames_queue)
    record_bytes = tf.decode_raw(value, tf.uint8)

    # Label: first byte, cast to int32. Result keeps shape [label_size] == [1].
    label = tf.cast(tf.slice(record_bytes, [0], [label_size]), tf.int32)

    # Image: remaining bytes, cast to float32 (values stay in 0..255).
    image = tf.cast(tf.slice(record_bytes, [label_size], [image_size]),
                    tf.float32)

    # Images are stored as D x H x W, but callers want H x W x D:
    # reshape to the stored layout, then move the depth axis last.
    image = tf.reshape(image, (IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH))
    image = tf.transpose(image, (1, 2, 0))
    return (image, label)
评论列表
文章目录