arrayiterator.py source code

python

Project: ngraph    Author: NervanaSystems
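import numpy as np  # needed for np.concatenate below


def __next__(self):
    """
    Returns a new minibatch of data with each call.

    Returns:
        dict: The next minibatch, mapping each key in ``self.data_arrays``
        (e.g. features and labels) to ``batch_size`` samples, plus an
        ``'iteration'`` entry holding the 1-based iteration count.

    Raises:
        StopIteration: once ``self.total_iterations`` batches have been returned.
    """
    if self.index >= self.total_iterations:
        raise StopIteration

    # Start offset of this batch, wrapping around the end of the dataset
    i1 = (self.start + self.index * self.batch_size) % self.ndata
    # Number of samples available before the end of the data is reached
    bsz = min(self.batch_size, self.ndata - i1)
    oslice1 = slice(i1, i1 + bsz)
    self.index += 1

    if self.batch_size > bsz:
        # Short batch: pad it with samples taken from the start of each array
        batch_bufs = {k: np.concatenate([src[oslice1], src[:self.batch_size - bsz]])
                      for k, src in self.data_arrays.items()}
    else:
        batch_bufs = {k: src[oslice1] for k, src in self.data_arrays.items()}

    batch_bufs['iteration'] = self.index
    return batch_bufs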
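To try the wrap-around batching outside ngraph, here is a minimal sketch that wraps the same `__next__` logic in a hypothetical `MiniArrayIterator` class. The constructor, `start = 0`, and the default `total_iterations` (one pass over the data, rounded up so the last partial batch is kept) are assumptions for illustration only; they are not ngraph's actual `ArrayIterator` constructor.

```python
import numpy as np


class MiniArrayIterator:
    """Hypothetical, simplified in-memory minibatch iterator (illustration only).

    Assumes every array in ``data_arrays`` has the same length along axis 0.
    """

    def __init__(self, data_arrays, batch_size, total_iterations=None):
        self.data_arrays = {k: np.asarray(v) for k, v in data_arrays.items()}
        lengths = {len(v) for v in self.data_arrays.values()}
        if len(lengths) != 1:
            raise ValueError("all arrays must have the same length")
        self.ndata = lengths.pop()
        self.batch_size = batch_size
        self.start = 0  # assumption: always begin at the first sample
        self.index = 0
        if total_iterations is None:
            # Assumption: one pass, ceil(ndata / batch_size) batches
            total_iterations = -(-self.ndata // batch_size)
        self.total_iterations = total_iterations

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.total_iterations:
            raise StopIteration
        # Start offset of this batch, wrapping around the end of the data
        i1 = (self.start + self.index * self.batch_size) % self.ndata
        bsz = min(self.batch_size, self.ndata - i1)
        oslice1 = slice(i1, i1 + bsz)
        self.index += 1
        if self.batch_size > bsz:
            # Short batch: pad with samples from the start of each array
            batch_bufs = {k: np.concatenate([src[oslice1], src[:self.batch_size - bsz]])
                          for k, src in self.data_arrays.items()}
        else:
            batch_bufs = {k: src[oslice1] for k, src in self.data_arrays.items()}
        batch_bufs['iteration'] = self.index
        return batch_bufs


if __name__ == "__main__":
    it = MiniArrayIterator({'x': np.arange(5), 'y': np.arange(5) * 10}, batch_size=2)
    for batch in it:
        print(batch['iteration'], batch['x'].tolist(), batch['y'].tolist())
    # 1 [0, 1] [0, 10]
    # 2 [2, 3] [20, 30]
    # 3 [4, 0] [40, 0]
```

With 5 samples and `batch_size=2`, the third batch starts at offset 4, finds only one sample left, and pads it with sample 0, so every batch keeps the full batch size. Passing a larger `total_iterations` simply keeps cycling through the data from the wrapped offset. Note that the padding takes at most `ndata` samples from the start, so a `batch_size` larger than twice the dataset would still yield a short batch.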