def test_batches_from_two_sets():
    """Check that ``combine_batches`` joins batches from two eternal streams.

    Stream 1 draws from ``data1`` with batch size 1, and stream 2 draws from
    ``data2`` with batch size 2. Each combined batch is therefore expected to
    hold 3 items: the first item comes from stream 1 and the remaining two
    from stream 2.

    Over six combined batches, stream 1 contributes 6 items and stream 2
    contributes 12 items. The test cuts each stream's output into consecutive
    chunks of ``len(data)`` items: 3 chunks of 2 for stream 1 and 4 chunks of
    3 for stream 2. It asserts that every chunk contains each element of its
    source data exactly once, in any order.
    """
    data1 = np.array(['a', 'b'])
    data2 = np.array(['c', 'd', 'e'])
    batch_size1, batch_size2 = 1, 2
    n_batches = 6

    batch_generator = combine_batches(
        eternal_batches(data1, batch_size=batch_size1),
        eternal_batches(data2, batch_size=batch_size2),
    )
    batches = list(islice(batch_generator, n_batches))
    assert [len(batch) for batch in batches] == [batch_size1 + batch_size2] * n_batches

    # The first `batch_size1` items of each combined batch come from stream 1,
    # the rest from stream 2.
    returned1 = np.concatenate([batch[:batch_size1] for batch in batches])
    returned2 = np.concatenate([batch[batch_size1:] for batch in batches])

    # 6 items / 2 elements -> 3 epochs; 12 items / 3 elements -> 4 epochs.
    epochs1 = np.split(returned1, 3)
    epochs2 = np.split(returned2, 4)

    # Assert per epoch (rather than via `all(<generator>)`) so that a failure
    # reports which stream and which epoch was wrong.
    for i, items in enumerate(epochs1):
        assert sorted(items) == ['a', 'b'], f"stream 1, epoch {i}: {list(items)}"
    for i, items in enumerate(epochs2):
        assert sorted(items) == ['c', 'd', 'e'], f"stream 2, epoch {i}: {list(items)}"