import torch


def create_padded_batch(max_length=100, max_tokens=None,
                        batch_first=False, sort=False,
                        pack=False, augment=False):
    # Returns a collate function for a DataLoader. `batch_sequences`
    # (defined elsewhere in this module) pads a list of tensors into
    # a single batch, optionally packing it for RNN consumption.
    def collate(seqs, sort=sort, pack=pack):
        if not torch.is_tensor(seqs[0]):
            if sort or pack:  # packing requires a batch sorted by length
                # sort by the length of the first sequence in each tuple
                seqs.sort(key=lambda x: len(x[0]), reverse=True)
            # TODO: for now, just the first input will be packed
            return tuple([collate(s, sort=False, pack=pack and (i == 0))
                          for i, s in enumerate(zip(*seqs))])
        return batch_sequences(seqs, max_length=max_length,
                               max_tokens=max_tokens,
                               batch_first=batch_first,
                               sort=False, pack=pack, augment=augment)
    return collate
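A minimal usage sketch, assuming a dataset that yields (source, target) pairs of 1-D LongTensors and assuming batch_sequences returns a (padded_tensor, lengths) pair per sequence set; the dataset variable and batch size here are illustrative:

from torch.utils.data import DataLoader

# dataset is a hypothetical torch Dataset yielding (src, tgt) tensor tuples
loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    collate_fn=create_padded_batch(max_length=50,
                                                   batch_first=True,
                                                   sort=True))

for (src, src_lengths), (tgt, tgt_lengths) in loader:
    pass  # src and tgt are padded batches; lengths hold the original sizes

Because collate recurses over zip(*seqs), each element of the sample tuple is batched independently, which is why the loader yields one (batch, lengths) pair per field.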