tools.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:hart 作者: akosiorek 项目源码 文件源码
def get_minibatch(self):
        """Build the input pipeline on first call and return the cached minibatch.

        On the first call this wires up a TF1 queue-based pipeline:

        1. ``nct.run_py2tf_queue`` runs ``self.n_threads`` threads that call
           ``process_entry(self.get(), ...)`` to produce single samples with
           keys ``'img'``, ``'bbox'`` and ``'presence'``.
        2. ``tf.train.batch`` groups ``self.batch_size`` samples, padding the
           variable-length leading axis (``dynamic_pad=True``).
        3. The batch axis is moved from position 0 to position 1, so every
           tensor becomes ``(None, batch_size, ...)``.
        4. If ``self.storage_dtype`` is not ``tf.float32``, ``'img'`` is cast to
           ``tf.float32``.
        5. The result is buffered in a 2-slot ``tf.FIFOQueue`` fed by two
           enqueue threads registered via ``tf.train.add_queue_runner``.

        Later calls return the same cached dict without rebuilding the graph.
        The registered queue runners must be started (e.g. with
        ``tf.train.start_queue_runners``) before these tensors can be evaluated.

        Returns:
            dict: ``'img'`` -> float32 ``(None, batch_size, *img_size, C)``,
            where ``C`` is 3, or 4 when ``self.depth_folder`` is set;
            ``'bbox'`` -> float32 ``(None, batch_size, 1, 4)``;
            ``'presence'`` -> uint8 ``(None, batch_size, 1)``.
        """
        if self.minibatch is None:

            self.img_store = ImageStore(self.img_size, self.in_memory, self.storage_dtype)

            def get_single_sample():
                return process_entry(self.get(), 1, self.img_store,
                                     self.depth_folder, self.bbox_scale)

            # Three colour channels plus one extra channel when a depth folder is given.
            n_channels = 3 + int(self.depth_folder is not None)
            shapes = [(None,) + tuple(self.img_size) + (n_channels,), (None, 1, 4),
                      (None, 1)]
            dtypes = [self.storage_dtype, tf.float32, tf.uint8]
            names = ['img', 'bbox', 'presence']

            # The second return value (queue size tensor) is not needed here.
            sample, _ = nct.run_py2tf_queue(get_single_sample, dtypes, shapes=shapes,
                                            names=names, n_threads=self.n_threads,
                                            capacity=2 * self.batch_size,
                                            name='{}/py2tf_queue'.format(self.name))

            # NOTE(review): tf.train.batch's `capacity` counts individual samples,
            # not batches, so at most 2 samples are buffered here. This is likely
            # a throughput bottleneck when batch_size > 2. Confirm intent.
            minibatch = tf.train.batch(sample, self.batch_size, dynamic_pad=True, capacity=2)

            # Move the batch axis from 0 to 1: (batch, None, ...) -> (None, batch, ...).
            # Iterate over a snapshot of the items because dict.iteritems() does
            # not exist in Python 3 and the loop rebinds the dict's values.
            for k, v in list(minibatch.items()):
                unpacked = tf.unstack(v)  # needs a static batch dim (no smaller final batch)
                unpacked = [u[:, tf.newaxis] for u in unpacked]
                minibatch[k] = tf.concat(axis=1, values=unpacked)

            if self.storage_dtype != tf.float32:
                minibatch[names[0]] = tf.cast(minibatch[names[0]], tf.float32)
                # The FIFOQueue below must be declared with the post-cast dtype.
                dtypes[0] = tf.float32

            queue = tf.FIFOQueue(2, dtypes, names=names)
            enqueue_op = queue.enqueue(minibatch)
            # Two threads run the same enqueue op to keep the prefetch queue full.
            runner = tf.train.QueueRunner(queue, [enqueue_op] * 2)
            tf.train.add_queue_runner(runner)
            minibatch = queue.dequeue()
            # The FIFOQueue was created without `shapes`, so dequeued tensors
            # carry no static shape. Restore it, including the batch axis at 1.
            for name, shape in zip(names, shapes):
                minibatch[name].set_shape((shape[0], self.batch_size) + shape[1:])

            self.minibatch = minibatch

        return self.minibatch
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号