def distorted_inputs():
    """Build a batched, shuffled training input pipeline (TF1 queue API).

    Reads the dataset index via ``load_data(FLAGS.data_dir)``. Each record
    must support ``record['filename']`` and ``record['label_index']``.
    Filename/label pairs are sliced into a shuffling input producer. Four
    preprocessing branches are built, each reading the file contents and
    calling ``image_preprocessing`` with an empty bbox list, ``train=True``
    and its branch index. All branches feed ``tf.train.batch_join``.

    Whether ``image_preprocessing`` applies random distortions is not
    visible here; the function name suggests it, but confirm against that
    helper.

    Returns:
        A pair ``(images, labels)``:
        - ``images``: float32 tensor reshaped to
          ``[FLAGS.batch_size, FLAGS.input_size, FLAGS.input_size, 3]``.
        - ``labels``: the batched label indices reshaped to
          ``[FLAGS.batch_size]``.
    """
    records = load_data(FLAGS.data_dir)
    all_filenames = [record['filename'] for record in records]
    all_labels = [record['label_index'] for record in records]
    filename, label_index = tf.train.slice_input_producer(
        [all_filenames, all_labels], shuffle=True)

    # One tensor list per preprocessing branch; batch_join runs one
    # enqueuing thread per list.
    num_preprocess_threads = 4
    per_thread_examples = []
    for thread_id in range(num_preprocess_threads):
        encoded_image = tf.read_file(filename)
        # Empty bbox list (fresh per branch) and train=True.
        preprocessed = image_preprocessing(encoded_image, [], True, thread_id)
        per_thread_examples.append([preprocessed, label_index])

    queue_capacity = 2 * num_preprocess_threads * FLAGS.batch_size
    images, label_index_batch = tf.train.batch_join(
        per_thread_examples,
        batch_size=FLAGS.batch_size,
        capacity=queue_capacity)

    # Square RGB images: the same flag is used for height and width,
    # and the channel depth is fixed at 3.
    height = width = FLAGS.input_size
    images = tf.reshape(tf.cast(images, tf.float32),
                        shape=[FLAGS.batch_size, height, width, 3])
    labels = tf.reshape(label_index_batch, [FLAGS.batch_size])
    return images, labels
# 评论列表 (web-page residue: "comment list")
# 文章目录 (web-page residue: "table of contents")