def __iter__(self):
    """Yield sample indices grouped into length-sorted batches.

    Each entry of ``self.lengths`` is read as a pair ``(entry[0], entry[1])``.
    Indices are ordered by descending first length, then descending second
    length, with a random key breaking exact ties. The ordered indices are
    cut into consecutive chunks of ``self.batch_size`` (the final chunk may
    be shorter). When ``self.shuffle`` is truthy the order of the chunks is
    shuffled, while the order inside each chunk is kept.

    Returns:
        An iterator over the flattened indices (NumPy integer scalars).
    """
    # Lengths are negated so that the ascending argsort below produces a
    # descending-length order. ``np.float64`` is used instead of the
    # ``np.float_`` alias, which no longer exists in NumPy 2.0.
    sort_keys = np.array(
        [(-pair[0], -pair[1], np.random.random()) for pair in self.lengths],
        dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float64)]
    )
    indices = np.argsort(sort_keys, order=('l1', 'l2', 'rand'))
    batches = [indices[start:start + self.batch_size]
               for start in range(0, len(indices), self.batch_size)]
    if self.shuffle:
        # Shuffles the list of batches in place; members of a batch stay
        # adjacent, so similarly sized samples remain together.
        np.random.shuffle(batches)
    return iter([i for batch in batches for i in batch])
评论列表
文章目录