def collate(samples, pad_idx, eos_idx):
    """Collate a list of sample dicts into a padded translation mini-batch.

    Args:
        samples: list of dicts, each holding an 'id' scalar tensor plus
            'source' and 'target' token tensors of varying lengths.
        pad_idx: token index used to pad sequences to a common length.
        eos_idx: token index of the end-of-sentence symbol.

    Returns:
        dict with batched 'id', 'src_tokens', 'input_tokens' (shifted
        decoder input), 'target', and 'ntokens' (total target tokens).
    """
    def batch_field(key, left_pad, move_eos_to_beginning=False):
        # Pad every sample's tensor under `key` to one common length.
        return LanguagePairDataset.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx,
            left_pad,
            move_eos_to_beginning,
        )

    total_target_tokens = sum(len(s['target']) for s in samples)
    return {
        'id': torch.LongTensor([s['id'].item() for s in samples]),
        'src_tokens': batch_field('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
        # Teacher forcing: decoder input is the target shifted right, with
        # EOS rotated to the front as the first "previous output" token.
        'input_tokens': batch_field(
            'target',
            left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
            move_eos_to_beginning=True,
        ),
        'target': batch_field('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
        'ntokens': total_target_tokens,
    }
# NOTE(review): removed webpage-scrape artifacts ("评论列表" / "文章目录",
# i.e. "comment list" / "article table of contents") — not part of the code.