def get_queue(nodes,
              queue_type='fifo',
              batch_size=256,
              capacity=None,
              min_after_dequeue=None,
              shape_flag=True,
              seed=0):
    """Build a queue whose components mirror the tensors in ``nodes``.

    Built on top of
    https://indico.io/blog/tensorflow-data-input-part2-extensions/

    Assumes the module-level name ``tf`` is TensorFlow 1.x (the import is not
    visible here — confirm).

    Args:
        nodes: Mapping of component name -> object exposing ``.dtype`` and
            ``.get_shape()`` (presumably ``tf.Tensor`` objects — confirm
            against callers). The mapping's iteration order fixes the order
            of the queue's ``names``/``dtypes``/``shapes``.
        queue_type: ``'random'`` -> ``tf.RandomShuffleQueue``,
            ``'fifo'`` -> ``tf.FIFOQueue``,
            ``'padding_fifo'`` -> ``tf.PaddingFIFOQueue``,
            ``'priority'`` -> ``tf.PriorityQueue``.
        batch_size: Used to derive the default ``capacity``. When it equals
            1, the queue is created with ``shapes=None`` (no static shapes).
        capacity: Maximum number of queued elements; defaults to
            ``2 * batch_size``.
        min_after_dequeue: Passed only to the ``'random'`` queue; defaults to
            ``capacity // 2``.
        shape_flag: If true, each component's shape is the node's static
            shape with its first dimension dropped (presumably the batch
            axis — TODO confirm); otherwise the full static shape is used.
        seed: Passed only to the ``'random'`` queue.

    Returns:
        The constructed queue object.

    Raises:
        ValueError: If ``queue_type`` is not one of the values listed above.
    """
    if capacity is None:
        capacity = 2 * batch_size
    if min_after_dequeue is None:
        min_after_dequeue = capacity // 2

    names = []
    dtypes = []
    shapes = []
    for name, node in nodes.items():
        names.append(name)
        dtypes.append(node.dtype)
        shape = node.get_shape()
        # Optionally strip the leading dimension so a queue element describes
        # a single item rather than a whole batch (assumption — confirm).
        shapes.append(shape[1:] if shape_flag else shape)
    if batch_size == 1:
        # No static shapes are declared when batch_size is 1.
        shapes = None

    if queue_type == 'random':
        queue = tf.RandomShuffleQueue(capacity=capacity,
                                      min_after_dequeue=min_after_dequeue,
                                      dtypes=dtypes,
                                      shapes=shapes,
                                      names=names,
                                      seed=seed)
    elif queue_type == 'fifo':
        queue = tf.FIFOQueue(capacity=capacity,
                             dtypes=dtypes,
                             shapes=shapes,
                             names=names)
    elif queue_type == 'padding_fifo':
        queue = tf.PaddingFIFOQueue(capacity=capacity,
                                    dtypes=dtypes,
                                    shapes=shapes,
                                    names=names)
    elif queue_type == 'priority':
        # NOTE(review): tf.PriorityQueue takes ``types`` (not ``dtypes``) and
        # prepends an int64 priority component, so ``names`` may need one
        # extra entry for it — confirm this branch works as intended.
        queue = tf.PriorityQueue(capacity=capacity,
                                 types=dtypes,
                                 shapes=shapes,
                                 names=names)
    else:
        # Previously the exception was constructed but never raised, so the
        # caller got an UnboundLocalError from ``return queue`` instead.
        raise ValueError('Queue type %s not recognized' % queue_type)
    return queue
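# NOTE: two lines of scraped web-page residue stood here ("评论列表" = "Comment list",
# "文章目录" = "Table of contents"); they are not part of the code.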