import torch
from torch.autograd import Variable  # pre-0.4 PyTorch autograd API


def __getitem__(self, index):
    assert index < self.numBatches, "%d >= %d" % (index, self.numBatches)
    # Slice out the index-th batch and pad it; `lengths` holds the true
    # (unpadded) source lengths.
    srcBatch, lengths = self._batchify(
        self.src[index * self.batchSize:(index + 1) * self.batchSize],
        include_lengths=True)
    tgtBatch = self._batchify(
        self.tgt[index * self.batchSize:(index + 1) * self.batchSize])

    # Within the batch, sort by decreasing source length (as required by
    # pack_padded_sequence); keep the permutation in `indices` so the
    # caller can restore the original order.
    indices = range(len(srcBatch))
    batch = zip(indices, srcBatch, tgtBatch)
    batch, lengths = zip(*sorted(zip(batch, lengths), key=lambda x: -x[1]))
    indices, srcBatch, tgtBatch = zip(*batch)

    def wrap(b):
        # Stack to (batch, seq_len), then transpose to (seq_len, batch),
        # the time-major layout the encoder/decoder expect.
        b = torch.stack(b, 0).t().contiguous()
        if self.cuda:
            b = b.cuda()
        b = Variable(b, volatile=self.eval)
        return b

    return (wrap(srcBatch), lengths), wrap(tgtBatch), indices
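
For context, here is a minimal sketch of how a batch produced by this __getitem__ might be consumed in a training loop. It assumes `dataset` is an instance of the class above (exposing `numBatches`) and `model` is a hypothetical seq2seq module; neither name appears in the snippet itself.

for i in range(dataset.numBatches):
    (src, lengths), tgt, indices = dataset[i]
    # src and tgt are (seq_len, batch) Variables; lengths is sorted in
    # decreasing order, matching what pack_padded_sequence expects.
    outputs = model(src, lengths, tgt)  # hypothetical model call
    # `indices` maps the length-sorted batch back to its original order,
    # e.g. to realign translations with their source sentences later.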