def __init__(self, src, trgt, spkr, seq_len):
self.seq_len = seq_len
self.start = True
self.speakers = spkr
self.srcBatch = src[0]
self.srcLenths = src[1]
# split batch
self.tgtBatch = list(torch.split(trgt[0], self.seq_len, 0))
self.tgtBatch.reverse()
self.len = len(self.tgtBatch)
# split length list
batch_seq_len = len(self.tgtBatch)
self.tgtLenths = [self.split_length(l, batch_seq_len) for l in trgt[1]]
self.tgtLenths = torch.stack(self.tgtLenths)
self.tgtLenths = list(torch.split(self.tgtLenths, 1, 1))
self.tgtLenths = [x.squeeze() for x in self.tgtLenths]
self.tgtLenths.reverse()
assert len(self.tgtLenths) == len(self.tgtBatch)
评论列表
文章目录