def iterbatches(arrays, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True):
    """Yield aligned mini-batches drawn from several equally long arrays.

    Every input is converted with ``np.asarray`` and indexed along its first
    axis with the same index set, so the i-th rows of all arrays stay together.

    Args:
        arrays: Iterable of array-likes sharing the same length along axis 0.
        num_batches: Number of batches to split the data into. Mutually
            exclusive with ``batch_size``.
        batch_size: Number of rows per batch. Mutually exclusive with
            ``num_batches``.
        shuffle: If true, the row order is permuted with ``np.random.shuffle``
            (global NumPy RNG) before batching.
        include_final_partial_batch: Only meaningful together with
            ``batch_size``. If false, a batch holding fewer than
            ``batch_size`` rows is dropped. When splitting by ``num_batches``
            the batches are near-equal in size and all of them are yielded.

    Yields:
        A tuple with one ndarray per input array, each holding the rows
        selected for the current batch.

    Raises:
        AssertionError: If not exactly one of ``num_batches`` / ``batch_size``
            is given, or if the arrays differ in length along axis 0. Because
            this is a generator, the checks run when iteration starts.
    """
    assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both'
    arrays = tuple(map(np.asarray, arrays))
    n = arrays[0].shape[0]
    assert all(a.shape[0] == n for a in arrays[1:]), 'All arrays must have the same length along axis 0'

    inds = np.arange(n)
    if shuffle:
        np.random.shuffle(inds)

    if num_batches is None:
        # Explicit split points at batch_size, 2 * batch_size, ...; the last
        # chunk may therefore be shorter than batch_size.
        sections = np.arange(0, n, batch_size)[1:]
    else:
        # An integer makes array_split produce that many near-equal chunks.
        sections = num_batches

    # Dropping short batches only makes sense when a fixed batch_size was
    # requested. Comparing len() against batch_size=None would never match
    # and would silently discard every batch in num_batches mode.
    drop_short = (not include_final_partial_batch) and batch_size is not None

    for batch_inds in np.array_split(inds, sections):
        if drop_short and len(batch_inds) != batch_size:
            continue
        yield tuple(a[batch_inds] for a in arrays)
评论列表
文章目录