multi_language.py source file

python

Project: seq2seq.pytorch  Author: eladhoffer
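import torch

# NOTE: batch_sequences is not shown in this snippet; it is assumed to be
# defined elsewhere in multi_language.py or imported there.


def create_padded_batch(max_length=100, max_tokens=None,
                        batch_first=False, sort=False,
                        pack=False, augment=False):
    """Build a collate function that pads (and optionally packs) sequences."""
    def collate(seqs, sort=sort, pack=pack):
        if not torch.is_tensor(seqs[0]):
            # Each sample holds several sequences (e.g. source and target).
            if sort or pack:  # packing requires a batch sorted by length
                # sort samples by the length of their first sequence;
                # sorted() avoids mutating the caller's list
                seqs = sorted(seqs, key=lambda x: len(x[0]), reverse=True)
            # TODO: for now, just the first input will be packed
            return tuple(collate(s, sort=False, pack=pack and (i == 0))
                         for i, s in enumerate(zip(*seqs)))
        # Base case: a flat batch of tensors.
        return batch_sequences(seqs, max_length=max_length,
                               max_tokens=max_tokens,
                               batch_first=batch_first,
                               sort=False, pack=pack, augment=augment)
    return collate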
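
The recursive pattern in `collate` can be easier to follow without the tensor machinery. The following is a minimal, torch-free sketch of the same idea, not the project's implementation: `pad_batch` is a hypothetical stand-in for `batch_sequences`, plain Python lists replace tensors, and packing is left out. It only illustrates how whole samples are sorted by the length of their first sequence and then transposed with `zip(*samples)`, so that each sequence set is padded separately while source and target stay aligned.

```python
def pad_batch(seqs, pad=0):
    """Hypothetical stand-in for batch_sequences: right-pad lists to equal length."""
    longest = max(len(s) for s in seqs)
    padded = [list(s) + [pad] * (longest - len(s)) for s in seqs]
    lengths = [len(s) for s in seqs]
    return padded, lengths


def collate(samples, sort=True):
    """Collate samples following the recursive pattern of the snippet above."""
    # A tuple sample means several parallel sequence sets (e.g. source, target).
    if isinstance(samples[0], tuple):
        if sort:
            # Sort whole samples by the length of their first sequence,
            # longest first, so every set is reordered consistently.
            samples = sorted(samples, key=lambda x: len(x[0]), reverse=True)
        # zip(*samples) transposes the pairs into one group per sequence set.
        return tuple(collate(list(group), sort=False) for group in zip(*samples))
    # Base case: a flat list of sequences.
    return pad_batch(samples)


pairs = [([1, 2], [7, 8, 9]), ([3, 4, 5], [6]), ([9], [4, 4])]
(src, src_len), (tgt, tgt_len) = collate(pairs)
```

Because sorting is done once on the whole samples and the recursive calls pass `sort=False`, the second set is never reordered independently of the first. In the original function the returned `collate` would presumably be passed as the `collate_fn` argument of a `torch.utils.data.DataLoader`, though the snippet itself does not show that call.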