data_loader.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def get_batch(self, i, evaluation=False):
        def pad_to_longest(insts, max_len):
            inst_data = np.array([inst + [const.PAD] * (max_len - len(inst)) for inst in insts])
            inst_data_tensor = Variable(torch.from_numpy(inst_data), volatile=evaluation)
            if self.cuda:
                inst_data_tensor = inst_data_tensor.cuda()
            return inst_data_tensor

        bsz = min(self._batch_size, self._sents_size-1-i)

        src = pad_to_longest(self._src_sents[i:i+bsz], self._max_src)
        tgt = pad_to_longest(self._tgt_sents[i:i+bsz], self._max_tgt)
        label = Variable(torch.from_numpy(self._label[i:i+bsz]), volatile=evaluation)
        if self.cuda:
                label = label.cuda()

        return src, tgt, label
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号