data_loader.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def __next__(self):
        def pad_to_longest(insts, max_len):
            inst_data = np.array([inst + [const.PAD] * (max_len - len(inst)) for inst in insts])

            inst_data_tensor = Variable(torch.from_numpy(inst_data), volatile=self.evaluation)
            if self.cuda:
                inst_data_tensor = inst_data_tensor.cuda()
            return inst_data_tensor

        if self._step == self._stop_step:
            self._step = 0
            raise StopIteration()

        _start = self._step*self._batch_size
        _bsz = self._batch_size
        self._step += 1
        data = pad_to_longest(self._src_sents[_start:_start+_bsz], self._max_len)
        label = Variable(torch.from_numpy(self._label[_start:_start+_bsz]),
                    volatile=self.evaluation)
        if self.cuda:
            label = label.cuda()

        return data, label
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号