def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
max_positions=(1024, 1024), ignore_invalid_inputs=False,
descending=False):
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
if max_tokens is None:
max_tokens = float('Inf')
if max_sentences is None:
max_sentences = float('Inf')
indices = np.argsort(src.sizes, kind='mergesort')
if descending:
indices = np.flip(indices, 0)
return _make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, allow_different_src_lens=False)
评论列表
文章目录