def queue_setup(filename, mode, batch_size, num_readers, min_examples):
    """Build a queue-based input pipeline that yields serialized records.

    A filename queue feeds ``num_readers`` ``tf.TFRecordReader`` instances.
    Each reader enqueues the record values it reads into a shared examples
    queue, and a single ``QueueRunner`` driving all of those enqueue ops is
    registered through ``tf.train.queue_runner.add_queue_runner``.

    Args:
        filename: Path of the TFRecord file to read. It is wrapped in a
            one-element list for ``tf.train.string_input_producer``.
        mode: ``"train"`` selects a ``tf.RandomShuffleQueue`` (records are
            dequeued in random order); any other value selects a
            ``tf.FIFOQueue``.
        batch_size: Used only to size the examples queue.
        num_readers: Number of readers (and enqueue ops) to create. Must be
            at least 1.
        min_examples: ``min_after_dequeue`` for the shuffle queue in train
            mode; also contributes to the capacity of the examples queue.

    Returns:
        The string tensor produced by dequeuing one element from the
        examples queue (presumably one serialized example per run; confirm
        against the code that parses it).

    Raises:
        ValueError: If ``num_readers`` is smaller than 1.

    Note:
        This function only registers the queue runner. Starting the runner
        threads (e.g. with ``tf.train.start_queue_runners``) is left to the
        caller.
    """
    # With no readers there is no enqueue op feeding the examples queue, so
    # the returned dequeue op would block forever. Fail fast instead.
    if num_readers < 1:
        raise ValueError(
            "num_readers must be at least 1, got {!r}".format(num_readers))

    filename_queue = tf.train.string_input_producer(
        [filename], shuffle=True, capacity=16)

    # Both queue flavours share the same capacity: the minimum buffer plus
    # room for three extra batches.
    capacity = min_examples + 3 * batch_size
    if mode == "train":
        examples_queue = tf.RandomShuffleQueue(
            capacity=capacity,
            min_after_dequeue=min_examples,
            dtypes=[tf.string])
    else:
        examples_queue = tf.FIFOQueue(capacity=capacity, dtypes=[tf.string])

    # One reader per enqueue op; all readers pull from the same filename queue.
    enqueue_ops = []
    for _ in range(num_readers):
        reader = tf.TFRecordReader()
        _, value = reader.read(filename_queue)
        enqueue_ops.append(examples_queue.enqueue([value]))

    tf.train.queue_runner.add_queue_runner(
        tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
    return examples_queue.dequeue()
评论列表
文章目录