def argmax(self, z, max_len):
    # requires `torch` and (pre-0.4) `torch.autograd.Variable` to be imported
    # special symbols and batch size
    eos, bos = self.src_dict.get_eos(), self.src_dict.get_bos()
    batch = z.size(0)
    # output variables: accumulated scores, per-step predictions, and a
    # per-example mask that stays 1 until that example has emitted <eos>
    scores, preds, mask = 0, [], z.data.new(batch).long() + 1
    # model inputs: decoder hidden state conditioned on z, and a <bos> start token
    hidden = self.decoder.init_hidden_for(z)
    prev = Variable(z.data.new(batch).zero_().long() + bos, volatile=True)
    for _ in range(max_len):
        prev_emb = self.embeddings(prev).squeeze(0)
        dec_out, hidden = self.decoder(prev_emb, hidden, z=z)
        dec_out = self.project(dec_out.unsqueeze(0))
        # greedy step: pick the highest-scoring token for each example
        score, pred = dec_out.max(1)
        scores += score.squeeze().data
        preds.append(pred.squeeze().data)
        prev = pred
        # zero out the mask entries of examples that just produced <eos>
        # (element-wise check rather than only the first example in the batch)
        mask = mask * (pred.squeeze().data != eos).long()
        if mask.int().sum() == 0:  # every example has finished; stop early
            break
    return scores.tolist(), torch.stack(preds).transpose(0, 1).tolist()
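
For context, a minimal sketch of how this greedy decoder might be driven at inference time, assuming an already-constructed model instance of the class this method belongs to, a hypothetical `encode` method producing the latent code `z`, and an `itos` list on `src_dict` for mapping ids back to tokens; these names are illustrative assumptions, not part of the snippet above.

import torch
from torch.autograd import Variable  # pre-0.4 PyTorch API, matching the snippet above

# `model`, `model.encode` and `model.src_dict.itos` are assumed names for illustration
src = Variable(torch.LongTensor([[4, 17, 9, 3]]).t(), volatile=True)  # (seq_len, batch=1)
z = model.encode(src)                        # assumed encoder returning (batch, z_dim)
scores, hyps = model.argmax(z, max_len=50)   # greedy decode for at most 50 steps

for score, hyp in zip(scores, hyps):
    tokens = [model.src_dict.itos[i] for i in hyp]  # id -> token lookup (assumed)
    print(score, ' '.join(tokens))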