def decode(self, seq, pos):
    """Beam-search decode a source sequence into scored target sentences.

    Args:
        seq: source token-id tensor (single sequence expected; tiled across
            the beam). TODO confirm expected shape against caller.
        pos: position tensor matching ``seq``.

    Returns:
        List of ``(decoded_sentence, score)`` tuples, one per beam entry,
        where ``decoded_sentence`` is a space-joined word string with BOS/EOS
        stripped.
    """
    def length_penalty(step, len_penalty_w=1.):
        # GNMT-style length penalty in log space: w * log((5 + step) / 6).
        # BUG FIX: original called ``self.torch.FloatTensor`` which would
        # raise AttributeError; ``torch`` is the module, not an attribute.
        return (torch.log(torch.FloatTensor([5 + step]))
                - torch.log(torch.FloatTensor([6]))) * len_penalty_w

    # Each beam entry is (token-id list, cumulative log-prob score).
    top_seqs = [([BOS], 0)] * self.beam_size
    enc_outputs = self.model.enc(seq, pos)
    # Tile the source sequence and encoder outputs across the beam dimension.
    seq_beam = Variable(seq.data.repeat(self.beam_size, 1))
    enc_outputs_beam = [Variable(enc_output.data.repeat(self.beam_size, 1, 1))
                        for enc_output in enc_outputs]
    input_data = self.init_input()
    input_pos = torch.arange(1, 2).unsqueeze(0)  # single position index [1]
    input_pos = input_pos.repeat(self.beam_size, 1)
    input_pos = Variable(input_pos.long(), volatile=True)
    for step in range(1, self.args.max_word_len + 1):
        if self.cuda:
            input_pos = input_pos.cuda()
            input_data = input_data.cuda()
        dec_output = self.model.dec(enc_outputs_beam,
                                    seq_beam, input_data, input_pos)
        dec_output = dec_output[:, -1, :]  # keep only the last step's feature
        # dim=-1 makes the (previously implicit) softmax axis explicit for
        # the 2-D (beam, vocab) logits.
        out = F.log_softmax(self.model.linear(dec_output), dim=-1)
        lp = length_penalty(step)
        top_seqs, all_done, un_dones = self.beam_search(out.data + lp, top_seqs)
        if all_done:
            break
        input_data = self.update_input(top_seqs)
        # BUG FIX: the pruned beam state must feed the next decoder call.
        # The original assigned to an unused ``src_seq_beam`` name, so the
        # decoder kept attending over the stale, un-pruned ``seq_beam``.
        input_pos, seq_beam, enc_outputs_beam = self.update_state(
            step + 1, seq, enc_outputs, un_dones)
    tgts = []
    # Use a fresh loop name instead of shadowing the ``seq`` parameter.
    for cand in top_seqs:
        cor_idxs, score = cand
        cor_idxs = cor_idxs[1:-1]  # strip leading BOS and trailing EOS
        tgts += [(" ".join(self.src_idx2word[idx] for idx in cor_idxs), score)]
    return tgts