def __next__(self):
    """
    Return the next minibatch of data.

    The minibatch starts at row
    ``(self.start + self.index * self.batch_size) % self.ndata`` and contains
    ``self.batch_size`` rows (assuming each source array holds at least
    ``self.ndata`` rows). When it runs past row ``ndata`` it wraps back to
    row 0, repeating the data as many times as needed, so ``batch_size`` may
    exceed ``ndata``.

    Returns:
        dict: One entry per key of ``self.data_arrays`` holding that array's
        rows for this minibatch, plus an ``'iteration'`` entry holding the
        1-based number of minibatches returned so far.

    Raises:
        StopIteration: once ``self.total_iterations`` minibatches have been
        returned.
    """
    if self.index >= self.total_iterations:
        raise StopIteration
    i1 = (self.start + self.index * self.batch_size) % self.ndata
    bsz = min(self.batch_size, self.ndata - i1)
    self.index += 1

    # Slices to read from every source array, in order. The first is the
    # contiguous run starting at i1; any further ones wrap back to row 0.
    # Emitting one slice per wrap (rather than a single src[:remaining])
    # keeps the batch at full size even when batch_size > ndata.
    slices = [slice(i1, i1 + bsz)]
    remaining = self.batch_size - bsz
    while remaining > 0:
        take = min(remaining, self.ndata)
        slices.append(slice(0, take))
        remaining -= take

    if len(slices) > 1:
        batch_bufs = {k: np.concatenate([src[s] for s in slices])
                      for k, src in self.data_arrays.items()}
    else:
        # No wrap-around: plain slicing (for NumPy arrays this is a view,
        # not a copy).
        batch_bufs = {k: src[slices[0]] for k, src in self.data_arrays.items()}
    # NOTE: overwrites any data array that is itself keyed 'iteration'.
    batch_bufs['iteration'] = self.index
    return batch_bufs
# Page residue from the original source: "评论列表" ("Comment list"),
# "文章目录" ("Table of contents").