data_loader.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def __next__(self):
        def to_longest(insts):
            inst_data_tensor = Variable(torch.from_numpy(insts))
            if self.cuda:
                inst_data_tensor = inst_data_tensor.cuda()
            return inst_data_tensor

        if self._step == self._stop_step:
            self._step = 0
            raise StopIteration()

        _start = self._step*self._batch_size
        _bsz = self._batch_size
        self._step += 1
        data = to_longest(self._src_sents[_start: _start+_bsz])
        label = to_longest(self._label[_start: _start+_bsz])
        return data, label.contiguous().view(-1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号