def data_inputs(data_dir, data_type, batch_size, runner_threads):
    """Build a batched (images, labels) input pipeline for one data split.

    Args:
        data_dir: Directory containing the fixed-length record files.
        data_type: Which split to read; AUGMENTED_TRAINING_DATA selects the
            augmenting preprocessing path.
        batch_size: Number of examples per emitted batch.
        runner_threads: Thread count for the batching queue runner.

    Returns:
        A pair (images, labels): a float32 image batch in HWD layout and an
        int32 label tensor of shape [batch_size].
    """
    # Queue the split's input files and attach a fixed-length record reader.
    filename_queue = tf.train.string_input_producer(
        input_filenames(data_dir, data_type))
    record_reader = tf.FixedLengthRecordReader(record_bytes=INPUT_RECORD_BYTES)
    _, raw_record = record_reader.read(filename_queue)
    record_bytes = tf.decode_raw(raw_record, tf.uint8)
    # The first INPUT_LABEL_BYTES bytes hold the label; the remainder is the
    # image stored depth-major (DHW).
    label = tf.cast(tf.slice(record_bytes, [0], [INPUT_LABEL_BYTES]), tf.int32)
    depth_major_image = tf.reshape(
        tf.slice(record_bytes, [INPUT_LABEL_BYTES], [INPUT_IMAGE_BYTES]),
        [IMAGE_DEPTH, IMAGE_HEIGHT, IMAGE_WIDTH])
    # Rearrange DHW -> HWD and convert to float for preprocessing.
    float_image = tf.cast(tf.transpose(depth_major_image, [1, 2, 0]), tf.float32)
    if data_type == AUGMENTED_TRAINING_DATA:
        processed_image = augmented_standardized_image(float_image)
    else:
        processed_image = standardized_image(float_image)
    # Batch via a queue runner; capacity bounds how much is prefetched.
    images, labels = tf.train.batch(
        [processed_image, label],
        batch_size=batch_size,
        num_threads=runner_threads,
        capacity=10 * batch_size)
    return images, tf.reshape(labels, [batch_size])