def __next__(self):
    """Return the next ``(data, label)`` mini-batch as torch tensors.

    Batch number ``self._step`` covers rows
    ``[step * batch_size, step * batch_size + batch_size)`` of both
    ``self._src_sents`` and ``self._label``. Each slice is turned into a
    tensor with ``torch.from_numpy``, so both attributes must be NumPy
    arrays. No dtype cast is performed: the tensor keeps the array's dtype.
    NOTE(review): the original helper was called ``to_longest``, which
    suggests int64 inputs are expected; confirm with the data loader.
    When ``self.cuda`` is truthy, both tensors are moved with ``.cuda()``.
    The label tensor is flattened to 1-D via ``.contiguous().view(-1)``.

    ``self._stop_step`` is set outside this method. When the step counter
    equals it, the counter is rewound to 0 before ``StopIteration`` is
    raised, so the same object can be iterated again for another pass.

    Relies on ``torch`` and ``Variable`` being imported at file level.
    """
    if self._step == self._stop_step:
        # Rewind first so a later ``for`` loop starts a fresh pass.
        self._step = 0
        raise StopIteration()

    lo = self._step * self._batch_size
    hi = lo + self._batch_size
    # Advance before converting, so the counter moves even if conversion fails.
    self._step += 1

    def as_tensor(array):
        # Wrap the NumPy slice (shares memory) and optionally move it to the GPU.
        tensor = Variable(torch.from_numpy(array))
        return tensor.cuda() if self.cuda else tensor

    batch_data = as_tensor(self._src_sents[lo:hi])
    batch_label = as_tensor(self._label[lo:hi])
    return batch_data, batch_label.contiguous().view(-1)
评论列表
文章目录