def input_pipeline_dis(filenames, batch_size, read_threads=4, num_epochs=None, is_training=True):
    """Build a TF1 queue-based input pipeline that yields shuffled batches.

    A filename queue (``tf.train.string_input_producer``) feeds
    ``read_threads`` parallel reader subgraphs built by
    ``read_my_file_format_dis``; their outputs are merged and shuffled by
    ``tf.train.shuffle_batch_join``.

    This function does not start queue runners; the caller must do so
    (e.g. ``tf.train.start_queue_runners``) before evaluating the batches.

    Args:
        filenames: File paths handed to ``tf.train.string_input_producer``.
        batch_size: Number of examples per output batch.
        read_threads: Number of parallel reader subgraphs to create.
        num_epochs: Number of passes over ``filenames`` before the filename
            queue is exhausted. If ``None``, ``FLAGS.num_epochs`` is used
            (the value this function always used previously).
        is_training: If True, filenames are shuffled and a larger shuffle
            buffer (``min_after_dequeue=300``) is used. If False, filename
            order is kept and the buffer is 10. Note that the examples
            themselves still pass through ``shuffle_batch_join`` either way.

    Returns:
        A ``(clip_batch, label_batch, text_batch)`` tuple of batched tensors,
        one per component returned by ``read_my_file_format_dis``.

    Raises:
        RuntimeError: If there is no default TensorFlow session, which is
            needed to initialize the pipeline's local variables.
    """
    # Previously the ``num_epochs`` argument was ignored in favor of
    # FLAGS.num_epochs; honor it, falling back to the flag for old callers.
    epochs = num_epochs if num_epochs is not None else FLAGS.num_epochs

    filename_queue = tf.train.string_input_producer(
        filenames, num_epochs=epochs, shuffle=is_training)

    # One reader subgraph per thread; shuffle_batch_join runs one enqueue
    # thread per element of this list.
    example_list = [read_my_file_format_dis(filename_queue, is_training)
                    for _ in range(read_threads)]

    min_after_dequeue = 300 if is_training else 10
    # Headroom of a few batches above the minimum keeps dequeues from stalling.
    capacity = min_after_dequeue + 3 * batch_size
    clip_batch, label_batch, text_batch = tf.train.shuffle_batch_join(
        example_list, batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue)

    # string_input_producer keeps its epoch counter in a local variable,
    # which raises an "uninitialized" error unless initialized. Initialize
    # only after the whole pipeline is built, so any local variables created
    # by the readers are covered too.
    sess = tf.get_default_session()
    if sess is None:
        raise RuntimeError(
            "input_pipeline_dis requires a default TensorFlow session "
            "to initialize local variables")
    sess.run(tf.local_variables_initializer())

    return clip_batch, label_batch, text_batch
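# 评论列表 ("Comment list") - page-navigation residue from the scraped source, not code.
# 文章目录 ("Article table of contents") - page-navigation residue, not code.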