iterators.py source code

python

Project: ml-pyxis  Author: vicolab
import numpy as np  # the method below relies on NumPy imported as np


def batch(self):
    """Yield batches of samples drawn uniformly at random from the database.

    Yields
    ------
    (numpy.ndarray, ...)
        The sample values are returned in a tuple in the order of the
        `keys` specified by the user. If only one key is specified, the
        single array is yielded on its own rather than as a 1-tuple.
    """
    # Count the number of keys (i.e. data objects)
    nb_keys = len(self.keys)

    # Pre-allocate one output array per key. These buffers are reused and
    # overwritten in place on every iteration of the generator.
    data = []
    for key in self.keys:
        data.append(np.zeros((self.batch_size,) + self.spec[key]['shape'],
                             dtype=self.spec[key]['dtype']))

    while True:
        # Sample indices uniformly, with replacement
        batch_idxs = self.rng.randint(self.db.nb_samples,
                                      size=self.batch_size,
                                      dtype=np.uint64)

        # Fill row i of every key's buffer with the i-th sampled record
        for i, v in enumerate(batch_idxs):
            sample = self.db.get_sample(v)
            for k in range(nb_keys):
                data[k][i] = sample[self.keys[k]]

        # Account for batches with only one key
        if len(data) == 1:
            yield data[0]
        else:
            yield tuple(data)
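To try the method outside ml-pyxis, here is a minimal, self-contained sketch. It assumes a hypothetical in-memory stand-in for the database (`InMemoryDB`, exposing `nb_samples` and `get_sample(i)` like the `self.db` object above) and a hypothetical `UniformBatchIterator` class that holds the attributes the method reads (`keys`, `spec`, `batch_size`, `rng`, `db`). Neither class is part of ml-pyxis, whose real iterator and database classes may differ.

```python
import numpy as np


class InMemoryDB:
    """Hypothetical stand-in for the database: a dict of equal-length arrays."""

    def __init__(self, arrays):
        self.arrays = arrays
        self.nb_samples = len(next(iter(arrays.values())))

    def get_sample(self, i):
        # Return one sample as a dict mapping key -> value
        return {key: arr[int(i)] for key, arr in self.arrays.items()}


class UniformBatchIterator:
    """Hypothetical holder for the attributes that batch() relies on."""

    def __init__(self, db, keys, batch_size, seed=None):
        self.db = db
        self.keys = keys
        self.batch_size = batch_size
        self.rng = np.random.RandomState(seed)
        # Per-key shape of a single sample and its dtype, used for pre-allocation
        self.spec = {key: {'shape': db.arrays[key].shape[1:],
                           'dtype': db.arrays[key].dtype}
                     for key in keys}

    def batch(self):
        # Same logic as the method above, condensed
        data = [np.zeros((self.batch_size,) + self.spec[key]['shape'],
                         dtype=self.spec[key]['dtype'])
                for key in self.keys]
        while True:
            batch_idxs = self.rng.randint(self.db.nb_samples,
                                          size=self.batch_size,
                                          dtype=np.uint64)
            for i, v in enumerate(batch_idxs):
                sample = self.db.get_sample(v)
                for k, key in enumerate(self.keys):
                    data[k][i] = sample[key]
            yield data[0] if len(data) == 1 else tuple(data)


db = InMemoryDB({'x': np.arange(20, dtype=np.float32).reshape(10, 2),
                 'y': np.arange(10, dtype=np.int64)})
it = UniformBatchIterator(db, keys=['x', 'y'], batch_size=4, seed=0)
gen = it.batch()
x, y = next(gen)
print(x.shape, y.shape)  # (4, 2) (4,)
```

The generator yields the same array objects on every call to `next(gen)` and only overwrites their contents. If you need to keep a batch after requesting the next one, store `x.copy()` rather than `x`. Because indices come from `randint`, a single batch can contain the same record more than once.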