def _gather_into_queue(*tensor_lists):
    """Batch groups of tensor lists and serve them through a shuffling queue.

    Consecutive runs of ``FLAGS.batch_size`` entries of ``tensor_lists`` are
    each passed to ``transform.batch``. Each resulting batch is enqueued into
    a ``tf.RandomShuffleQueue`` with capacity ``FLAGS.batch_queue_capacity``
    and ``min_after_dequeue`` of half that capacity. All enqueue ops are
    grouped into a single op, which is registered via ``add_queue_runner``.
    The queue's size is recorded under the metric name
    ``"sorted_batches_in_queue"``.

    Args:
        *tensor_lists: Sequences of tensors. The queue's component dtypes
            are derived from the first one via ``dtypes``, and the static
            shapes of the dequeued components are taken from it too.
            Presumably every list has the same structure as the first one.
            The metric name suggests the caller has already sorted the
            lists (e.g. by length); NOTE(review): confirm against caller.
            The count must be a positive multiple of ``FLAGS.batch_size``.

    Returns:
        The value of ``queue.dequeue()``, i.e. one dequeued batch: a single
        tensor when the queue has one component, otherwise a list. Each
        component's static shape is that of the matching tensor in
        ``tensor_lists[0]`` with an unknown leading dimension prepended.

    Raises:
        ValueError: If no tensor lists are given, or their count is not a
            multiple of ``FLAGS.batch_size``.
    """
    batch_size = FLAGS.batch_size
    if not tensor_lists:
        # Checked separately: 0 % batch_size == 0, so the modulus test alone
        # would let an empty call through to an IndexError on tensor_lists[0].
        raise ValueError("_gather_into_queue() needs at least one tensor list")
    if len(tensor_lists) % batch_size != 0:
        # Explicit check rather than `assert`, which `python -O` strips;
        # otherwise a trailing partial group would presumably be batched
        # short without any error.
        raise ValueError(
            "number of tensor lists ({}) must be a multiple of "
            "FLAGS.batch_size ({})".format(len(tensor_lists), batch_size))

    capacity = FLAGS.batch_queue_capacity
    queue = tf.RandomShuffleQueue(capacity,
                                  capacity // 2,  # min_after_dequeue
                                  dtypes(*tensor_lists[0]))
    collections.add_metric(queue.size(), "sorted_batches_in_queue")

    # One enqueue per group of `batch_size` consecutive tensor lists. They are
    # grouped into a single op so the queue runner is given one enqueue op.
    enqueue_ops = [
        queue.enqueue(transform.batch(*tensor_lists[start:start + batch_size]))
        for start in range(0, len(tensor_lists), batch_size)
    ]
    add_queue_runner(queue, [tf.group(*enqueue_ops)])

    results = queue.dequeue()
    # Queue.dequeue() returns a bare Tensor (not a list) when the queue has a
    # single component, and iterating a graph-mode Tensor raises TypeError,
    # so normalise before attaching static shapes. `results` itself is
    # returned unchanged.
    components = results if isinstance(results, (list, tuple)) else [results]
    for result, tensor in zip(components, tensor_lists[0]):
        # transform.batch presumably adds a leading batch axis whose size is
        # not statically known (NOTE(review): confirm), hence the `None`.
        result.set_shape([None, *static_shape(tensor)])
    return results
评论列表
文章目录