def get_minibatch(self):
    """Build (once) and return the dict of minibatch tensors.

    On the first call this wires up a TF1 queue-based input pipeline:

    1. ``nct.run_py2tf_queue`` (called with ``n_threads=self.n_threads`` and
       capacity ``2 * self.batch_size``) wraps a Python callable that
       produces one processed sample via
       ``process_entry(self.get(), 1, self.img_store, self.depth_folder,
       self.bbox_scale)``.
    2. ``tf.train.batch`` with ``dynamic_pad=True`` groups
       ``self.batch_size`` samples.  Every declared per-sample shape has a
       leading ``None`` dimension, which gets padded across the batch.
    3. The first two axes of each batched tensor are swapped, so the batch
       axis ends up as axis 1.
    4. ``'img'`` is cast to ``tf.float32`` if ``self.storage_dtype`` differs.
    5. A capacity-2 ``tf.FIFOQueue`` with two enqueue ops (registered as a
       queue runner) buffers finished minibatches.

    The resulting dict is cached in ``self.minibatch``; later calls return it
    without building new graph ops.

    Returns:
        dict with keys ``'img'``, ``'bbox'`` and ``'presence'`` whose static
        shapes are ``(None, self.batch_size) + trailing_shape``, where
        ``trailing_shape`` is ``tuple(self.img_size) + (n_channels,)``,
        ``(1, 4)`` and ``(1,)`` respectively.
    """
    if self.minibatch is not None:
        return self.minibatch

    self.img_store = ImageStore(self.img_size, self.in_memory, self.storage_dtype)

    def get_single_sample():
        # Produces one processed entry for the py->TF queue.
        return process_entry(self.get(), 1, self.img_store,
                             self.depth_folder, self.bbox_scale)

    # 3 channels, plus one extra (presumably depth) channel when a depth
    # folder is configured.
    n_channels = 3 + int(self.depth_folder is not None)
    names = ['img', 'bbox', 'presence']
    dtypes = [self.storage_dtype, tf.float32, tf.uint8]
    # Leading None = variable-length axis, padded by dynamic_pad below.
    shapes = [(None,) + tuple(self.img_size) + (n_channels,),
              (None, 1, 4),
              (None, 1)]

    sample, _ = nct.run_py2tf_queue(get_single_sample, dtypes, shapes=shapes,
                                    names=names, n_threads=self.n_threads,
                                    capacity=2 * self.batch_size,
                                    name='{}/py2tf_queue'.format(self.name))

    minibatch = tf.train.batch(sample, self.batch_size, dynamic_pad=True, capacity=2)

    # Swap the first two axes: unstack along axis 0 (batch), give every item
    # a new axis 1, and concatenate along it -> (T, batch, ...) from
    # (batch, T, ...).  `.items()` works on Python 2 and 3 (unlike
    # `iteritems()`), and building a fresh dict avoids reassigning entries
    # of the dict being iterated.
    minibatch = {
        key: tf.concat(axis=1,
                       values=[item[:, tf.newaxis] for item in tf.unstack(value)])
        for key, value in minibatch.items()
    }

    if self.storage_dtype != tf.float32:
        minibatch[names[0]] = tf.cast(minibatch[names[0]], tf.float32)
        # Keep the buffering queue's dtype list in sync with the cast.
        dtypes[0] = tf.float32

    # Buffer up to two finished minibatches, filled by two enqueue ops.
    queue = tf.FIFOQueue(2, dtypes, names=names)
    enqueue_op = queue.enqueue(minibatch)
    tf.train.add_queue_runner(tf.train.QueueRunner(queue, [enqueue_op] * 2))
    minibatch = queue.dequeue()

    # The queue was created without `shapes`, so dequeued tensors carry no
    # static shape; restore it.  The padded leading axis stays None.
    for name, shape in zip(names, shapes):
        minibatch[name].set_shape((shape[0], self.batch_size) + shape[1:])

    self.minibatch = minibatch
    return self.minibatch
# Residue from the scraped source web page: "评论列表" (comment list),
# "文章目录" (table of contents).