def _shuffle(inputs, capacity, min_after_dequeue, num_threads):
    """Shuffle ``inputs`` through a ``tf.RandomShuffleQueue``.

    Builds a random-shuffle queue and registers a ``tf.train.QueueRunner``
    whose enqueue op list holds ``num_threads`` copies of one
    ``queue.enqueue(inputs)`` op. Returns the dequeued tensors with the static
    shapes of the corresponding inputs restored.

    Args:
      inputs: Either a dict mapping names to tensors, or a sequence of
        tensors.
      capacity: Passed to ``tf.RandomShuffleQueue`` as its capacity.
      min_after_dequeue: Passed to ``tf.RandomShuffleQueue`` as its
        ``min_after_dequeue``.
      num_threads: Number of copies of the enqueue op given to the queue
        runner.

    Returns:
      The result of ``queue.dequeue()``:
        - a dict keyed like ``inputs`` when ``inputs`` is a dict;
        - otherwise a list of tensors;
        - or a single tensor when ``inputs`` has exactly one element.

    Raises:
      ValueError: If ``inputs`` is empty.
    """
    is_dict = isinstance(inputs, dict)
    # Materialize (key, tensor) pairs once so that names, dtypes and the
    # shape restoration below all use the same ordering.
    pairs = list(inputs.items()) if is_dict else list(enumerate(inputs))
    if not pairs:
        raise ValueError('_shuffle requires at least one input tensor.')

    dtypes = [input_.dtype for _, input_ in pairs]
    queue_kwargs = {'names': [key for key, _ in pairs]} if is_dict else {}
    queue = tf.RandomShuffleQueue(
        capacity,
        min_after_dequeue,
        dtypes,
        **queue_kwargs)
    tf.train.add_queue_runner(tf.train.QueueRunner(
        queue,
        [queue.enqueue(inputs)] * num_threads))

    shuffled_inputs = queue.dequeue()
    # The queue is created without `shapes`, so restore each dequeued
    # tensor's static shape from the matching input.
    if not is_dict and len(pairs) == 1:
        # Without names, dequeue() returns a bare tensor for a single
        # component. Indexing it with [0] would slice the tensor rather than
        # select it, leaving the returned tensor's shape unset.
        shuffled_inputs.set_shape(pairs[0][1].get_shape())
    else:
        for key, input_ in pairs:
            shuffled_inputs[key].set_shape(input_.get_shape())
    return shuffled_inputs
评论列表
文章目录