data.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:loop 作者: facebookresearch 项目源码 文件源码
def __init__(self, src, trgt, spkr, seq_len):
        self.seq_len = seq_len
        self.start = True

        self.speakers = spkr
        self.srcBatch = src[0]
        self.srcLenths = src[1]

        # split batch
        self.tgtBatch = list(torch.split(trgt[0], self.seq_len, 0))
        self.tgtBatch.reverse()
        self.len = len(self.tgtBatch)

        # split length list
        batch_seq_len = len(self.tgtBatch)
        self.tgtLenths = [self.split_length(l, batch_seq_len) for l in trgt[1]]
        self.tgtLenths = torch.stack(self.tgtLenths)
        self.tgtLenths = list(torch.split(self.tgtLenths, 1, 1))
        self.tgtLenths = [x.squeeze() for x in self.tgtLenths]
        self.tgtLenths.reverse()

        assert len(self.tgtLenths) == len(self.tgtBatch)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号