def translate(self, inputs, max_length):
    """Greedy-decode a batch of inputs for at most ``max_length`` steps.

    At each step the decoder advances one token, the highest-scoring
    vocabulary entry is taken as the prediction, and decoding stops
    early once every sentence in the batch has emitted EOS.

    Args:
        inputs: batch of encoder inputs, forwarded to ``self.initialize``.
        max_length: hard cap on the number of decoding steps.

    Returns:
        Tensor of shape (steps, batch) holding the predicted token ids,
        where ``steps <= max_length``.
    """
    targets, init_states = self.initialize(inputs, eval=True)
    emb, output, hidden, context = init_states

    n_sents = targets.size(1)
    # One flag per sentence, set permanently once that sentence emits EOS.
    finished = targets[0].data.byte().new(n_sents).zero_()

    step_preds = []
    for _ in range(max_length):
        output, hidden = self.decoder.step(emb, output, hidden, context)
        scores = self.generator(output)
        # Greedy choice: argmax over the vocabulary for each sentence.
        token = scores.max(1)[1].view(-1).data
        step_preds.append(token)

        finished |= (token == lib.Constants.EOS)
        if finished.sum() == n_sents:
            # Every sentence has produced EOS — stop decoding early.
            break

        # Feed the chosen tokens back in as the next step's embeddings.
        emb = self.decoder.word_lut(Variable(token))

    return torch.stack(step_preds)
# NOTE(review): the two lines below ("评论列表" = comment list, "文章目录" =
# article table-of-contents) are leftover page-navigation text from a web
# scrape, not code; commented out so the module parses.