def __iter__(self):
    """Yield one dict of batch slices per step until ``total_iterations``.

    Each step selects batch slot ``self.index % self.nbatches`` along axis 1
    of every array in ``self.data_arrays``. Batches are cycled when
    ``total_iterations`` exceeds ``nbatches``. Each slice is passed through
    ``np.squeeze``.

    ``self.index`` lives on the instance, is advanced before each yield and
    is not reset here. Iterating the same object again therefore resumes
    where the previous iteration stopped, and yields nothing once exhausted.

    Yields:
        dict: maps each key of ``self.data_arrays`` to its squeezed slice.
    """
    while self.index < self.total_iterations:
        idx = self.index % self.nbatches
        self.index += 1
        # Slicing idx:idx + 1 (not indexing idx) keeps axis 1 until
        # np.squeeze. Trailing axes are implicit, so this one slice handles
        # arrays of any rank >= 2 and needs no per-key special-casing.
        # NOTE(review): np.squeeze drops *every* size-1 axis, not just
        # axis 1, so a size-1 leading/trailing dimension disappears too.
        # Confirm callers rely on that before narrowing to axis=1.
        # .items() iterates the same pairs as future.utils.viewitems on
        # both Python 2 and 3.
        yield {
            key: np.squeeze(arr[:, idx:idx + 1])
            for key, arr in self.data_arrays.items()
        }
评论列表
文章目录