def __next__(self):
    """Produce the next mini-batch as ``(enc_input, dec_input, label)``.

    Each element is the ``self._batch_size``-row window starting at row
    ``self._step * self._batch_size`` of ``self._enc_sents``,
    ``self._dec_sents`` and ``self._label`` respectively.  Each window is
    converted with ``torch.from_numpy``, wrapped in ``Variable``, and
    moved to the GPU with ``.cuda()`` when ``self.cuda`` is truthy.

    Raises:
        StopIteration: when ``self._step`` equals ``self._stop_step``.  The
            step counter is rewound to 0 before raising, so the same object
            can be iterated again.
    """

    def as_tensor(rows):
        # NOTE(review): torch.from_numpy needs a numpy array, so the three
        # sources are presumably numpy arrays — confirm against the class.
        tensor = Variable(torch.from_numpy(rows))
        return tensor.cuda() if self.cuda else tensor

    if self._step == self._stop_step:
        # Rewind first so a fresh pass starts from the beginning.
        self._step = 0
        raise StopIteration()

    begin = self._step * self._batch_size
    end = begin + self._batch_size
    # Advance before converting, mirroring the original ordering.
    self._step += 1

    sources = (self._enc_sents, self._dec_sents, self._label)
    return tuple(as_tensor(src[begin:end]) for src in sources)
评论列表
文章目录