def __iter__(self):
    """
    Generate minibatches of fixed-length sequences.

    Iteration continues while ``self.current_iter < self.total_iterations``.
    ``self.current_iter`` is advanced by one per minibatch and is not reset
    here, so a new generator resumes from the current counter value.

    For each minibatch, row ``b`` of every buffer in ``self.samples`` is
    filled with ``self.seq_len`` consecutive entries of the matching array
    in ``self.data_arrays``; positions wrap around modulo ``self.ndata``.
    The first position of row ``b`` is:

    * ``self.shuffle`` truthy:
      ``start + current_iter * stride + b * nbatches * seq_len``
    * otherwise:
      ``start + (current_iter * batch_size + b) * stride``

    Optional post-processing, applied in this order:

    * ``self.reverse_target``: the ``self.tgt_key`` buffer is reversed
      along axis 1.
    * ``self.get_prev_target``: ``samples['prev_tgt']`` is set to the target
      buffer rolled by one position along axis 1 (the roll is circular, so
      the last element wraps to the front).
    * ``self.include_iteration``: ``samples['iteration']`` is set to
      ``self.index`` (this method itself never modifies ``self.index``).

    Yields:
        dictionary: The next minibatch
            samples[key]: numpy array with shape (batch_size, seq_len, feature_dim)

        The same ``self.samples`` dictionary (and the same per-key buffers)
        is yielded every time and overwritten in place, so callers that
        need to keep a minibatch must copy it before advancing.
    """
    while self.current_iter < self.total_iterations:
        # The layout choice only depends on per-minibatch state, so the base
        # offset and the per-row step are computed once instead of per row.
        if self.shuffle:
            base = self.start + (self.current_iter * self.stride)
            row_step = self.nbatches * self.seq_len
        else:
            base = self.start + (self.current_iter * self.batch_size * self.stride)
            row_step = self.stride

        for batch_idx in range(self.batch_size):
            seq_start = base + (batch_idx * row_step)
            # Wrap around the end of the data instead of running off it.
            idcs = np.arange(seq_start, seq_start + self.seq_len) % self.ndata
            for key, source in self.data_arrays.items():
                self.samples[key][batch_idx] = source[idcs]

        self.current_iter += 1

        if self.reverse_target:
            target = self.samples[self.tgt_key]
            # Reverse from an explicit copy: the source view and destination
            # share memory, and relying on implicit overlap handling is fragile.
            target[:] = target[:, ::-1].copy()

        if self.get_prev_target:
            self.samples['prev_tgt'] = np.roll(self.samples[self.tgt_key], shift=1, axis=1)

        # Plain truthiness, consistent with the flags above; an identity test
        # against True would silently ignore values such as numpy.bool_.
        if self.include_iteration:
            self.samples['iteration'] = self.index

        yield self.samples
# Web-page navigation residue, not part of the code: "Comment list" / "Article table of contents"