def read_examples(input_files, batch_size, shuffle, num_epochs=None):
    """Build a queue-based pipeline that reads and batches text-line records.

    Args:
      input_files: Iterable of strings. Each string may hold several
        comma-separated paths/patterns; every one is expanded with
        `file_io.get_matching_files`.
      batch_size: Passed to the reader's `read_up_to` as the record count and
        to the batching op as the batch size.
      shuffle: When truthy, the filename queue is asked to shuffle and batches
        come from `tf.train.shuffle_batch`; otherwise `tf.train.batch` is
        used.
      num_epochs: Forwarded to the filename queue. Any falsy value (such as 0)
        is replaced by None first.

    Returns:
      The result of `tf.train.shuffle_batch` or `tf.train.batch` applied to
      the pair of tensors produced by the reader.
    """
    # Expand every comma-separated entry into the files it matches.
    matched_files = [
        filename
        for entry in input_files
        for pattern in entry.split(',')
        for filename in file_io.get_matching_files(pattern)
    ]

    # Use as many batching threads as there are CPUs.
    num_threads = multiprocessing.cpu_count()

    # Queue capacity is expressed as a multiple of batch_size: the number of
    # threads plus a small safety margin.
    batches_in_queue = num_threads + 3

    # A falsy epoch count (0 included) becomes None.
    epochs = num_epochs if num_epochs else None

    # Queue of filenames feeding the line reader.
    filenames = tf.train.string_input_producer(
        matched_files, num_epochs=epochs, shuffle=shuffle)
    reader = tf.TextLineReader()
    example_id, encoded_example = reader.read_up_to(filenames, batch_size)
    record_tensors = [example_id, encoded_example]

    if not shuffle:
        return tf.train.batch(
            record_tensors,
            batch_size,
            capacity=batches_in_queue * batch_size,
            enqueue_many=True,
            num_threads=num_threads)

    # Minimum number of elements left in the queue after a dequeue. A larger
    # value gives more randomness at the cost of more memory.
    shuffle_floor = 1000
    return tf.train.shuffle_batch(
        record_tensors,
        batch_size,
        capacity=shuffle_floor + batches_in_queue * batch_size,
        min_after_dequeue=shuffle_floor,
        enqueue_many=True,
        num_threads=num_threads)
# ==============================================================================
# Building the TF learn estimators
# ==============================================================================
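# 评论列表 (comment list) -- leftover web page navigation text, not code
# 文章目录 (article table of contents) -- leftover web page navigation text, not code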