Dataset.py source code

python

Project: bandit-nmt  Author: khanhptnk
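# Method of the project's Dataset class, written against the pre-0.4
# PyTorch API (Variable, volatile). Relies on the class attributes used below.
import torch
from torch.autograd import Variable

def __getitem__(self, index):
    # Valid batch indices are 0 .. numBatches - 1.
    assert index < self.numBatches, "%d >= %d" % (index, self.numBatches)
    start, end = index * self.batchSize, (index + 1) * self.batchSize
    # _batchify brings the slice to a common length (torch.stack below needs
    # equal-length rows); include_lengths=True also returns the source lengths.
    srcBatch, lengths = self._batchify(self.src[start:end], include_lengths=True)
    tgtBatch = self._batchify(self.tgt[start:end])

    # Within the batch, sort by decreasing source length, remembering each
    # example's original position in `indices`.
    indices = range(len(srcBatch))
    batch = zip(indices, srcBatch, tgtBatch)
    batch, lengths = zip(*sorted(zip(batch, lengths), key=lambda x: -x[1]))
    indices, srcBatch, tgtBatch = zip(*batch)

    def wrap(b):
        # Stack to (batch, seq_len), then transpose to time-major (seq_len, batch).
        b = torch.stack(b, 0).t().contiguous()
        if self.cuda:
            b = b.cuda()
        # `volatile` disables graph building in PyTorch < 0.4; newer versions
        # ignore it with a warning (use torch.no_grad() instead).
        b = Variable(b, volatile=self.eval)
        return b

    return (wrap(srcBatch), lengths), wrap(tgtBatch), indices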
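A minimal, dependency-free sketch of the two steps this method performs, assuming each example is a plain list of token ids and a hypothetical padding id of 0 (the project's real `_batchify` and padding constant are not shown here). `sort_batch` mirrors the zip/sorted/zip idiom in the method; `restore_order` and `pad_time_major` are hypothetical helpers illustrating what the returned `indices` and the `.t()` transpose are for. A common reason for sorting by decreasing length is that `torch.nn.utils.rnn.pack_padded_sequence` expects that order by default (`enforce_sorted=True` since PyTorch 1.1; earlier versions always required it), and PyTorch RNNs take time-major `(seq_len, batch)` input unless `batch_first=True`.

```python
from typing import Any, List, Sequence, Tuple

PAD = 0  # hypothetical padding id, not taken from the project


def sort_batch(src: Sequence[Sequence[int]], tgt: Sequence[Any]) -> Tuple[tuple, tuple, tuple, tuple]:
    """Sort (src, tgt) pairs by decreasing source length; assumes a non-empty batch."""
    lengths = [len(s) for s in src]
    batch = zip(range(len(src)), src, tgt)
    # Pair each example with its length, sort on the length, then unzip again.
    batch, lengths = zip(*sorted(zip(batch, lengths), key=lambda x: -x[1]))
    indices, src_sorted, tgt_sorted = zip(*batch)
    return indices, src_sorted, tgt_sorted, lengths


def restore_order(outputs: Sequence[Any], indices: Sequence[int]) -> List[Any]:
    """Hypothetical helper: put per-example outputs back in their original positions."""
    restored: List[Any] = [None] * len(outputs)
    for out, original_pos in zip(outputs, indices):
        restored[original_pos] = out
    return restored


def pad_time_major(seqs: Sequence[Sequence[int]], pad: int = PAD) -> List[List[int]]:
    """Pad to the longest sequence, then transpose batch x len -> len x batch."""
    max_len = max(len(s) for s in seqs)
    padded = [list(s) + [pad] * (max_len - len(s)) for s in seqs]
    return [list(col) for col in zip(*padded)]


src = [[5, 6], [7, 8, 9, 10], [11]]
tgt = ["a", "b", "c"]
indices, src_sorted, tgt_sorted, lengths = sort_batch(src, tgt)
print(indices, lengths)                      # (1, 0, 2) (4, 2, 1)
print(restore_order(tgt_sorted, indices))    # ['a', 'b', 'c']
print(pad_time_major(src_sorted))            # [[7, 5, 11], [8, 6, 0], [9, 0, 0], [10, 0, 0]]
```

Because `sorted` is stable, examples of equal length keep their original relative order, so the mapping through `indices` stays deterministic.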