data.py source code

python

Project: ParlAI    Author: facebookresearch
import torch


def collate(samples, pad_idx, eos_idx):
    # Excerpted from a class body: LanguagePairDataset (with its collate_tokens
    # helper and LEFT_PAD_SOURCE / LEFT_PAD_TARGET flags) must be defined to run this.

    def merge(key, left_pad, move_eos_to_beginning=False):
        # Pad the `key` sequences of all samples into a single batch tensor.
        return LanguagePairDataset.collate_tokens(
            [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning)

    return {
        'id': torch.LongTensor([s['id'].item() for s in samples]),
        'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
        # we create a shifted version of targets for feeding the previous
        # output token(s) into the next decoder step
        'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
                              move_eos_to_beginning=True),
        'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
        # total number of (unpadded) target tokens in the batch
        'ntokens': sum(len(s['target']) for s in samples),
    }
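The `merge` helper delegates all padding to `LanguagePairDataset.collate_tokens`, which is not shown here. As a rough illustration of what that padding and the EOS shift do, here is a minimal sketch, assuming every sequence already ends with the EOS token and using plain Python lists instead of torch tensors. The `collate_tokens` below is a hypothetical stand-in written for this example, not the library's implementation.

```python
from typing import List


def collate_tokens(values: List[List[int]], pad_idx: int, eos_idx: int,
                   left_pad: bool, move_eos_to_beginning: bool = False) -> List[List[int]]:
    """Hypothetical list-based sketch: pad variable-length sequences to one width."""
    size = max(len(v) for v in values)
    batch = []
    for v in values:
        if move_eos_to_beginning:
            # Assumes each sequence ends with EOS; rotate it to the front so that
            # position t holds the token preceding target position t.
            assert v[-1] == eos_idx, "sequence must end with EOS"
            v = [eos_idx] + v[:-1]
        padding = [pad_idx] * (size - len(v))
        batch.append(padding + v if left_pad else v + padding)
    return batch


if __name__ == "__main__":
    pad, eos = 1, 2  # hypothetical token ids
    targets = [[5, 6, 7, 2], [8, 2]]
    print(collate_tokens(targets, pad, eos, left_pad=False))
    # [[5, 6, 7, 2], [8, 2, 1, 1]]
    print(collate_tokens(targets, pad, eos, left_pad=False, move_eos_to_beginning=True))
    # [[2, 5, 6, 7], [2, 8, 1, 1]]
```

With `move_eos_to_beginning=True`, each row of `input_tokens` is the matching `target` row shifted right by one with EOS in front, so at every decoder step the model sees the previous output token and is trained to predict the next one.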