transform.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def decode(self, seq, pos):
    """Beam-search decode one source sequence into scored sentences.

    Args:
        seq: source token-id Variable. It is passed to ``self.model.enc`` and
            tiled ``self.beam_size`` times along dim 0 (presumably a
            ``(1, src_len)`` batch -- TODO confirm against callers).
        pos: source position Variable passed to ``self.model.enc``.

    Returns:
        A list with one ``(sentence, score)`` tuple per hypothesis in the
        final ``top_seqs``. ``sentence`` is the space-joined words of the
        hypothesis with its first and last token removed.
    """
    def length_penalty(step, len_penalty_w=1.):
        # len_penalty_w * (log(5 + step) - log(6)), built on self.torch's
        # FloatTensor type so it matches the device of the scores.
        return (torch.log(self.torch.FloatTensor([5 + step]))
                - torch.log(self.torch.FloatTensor([6]))) * len_penalty_w

    # Each beam starts from its own [BOS] hypothesis with score 0. A fresh
    # list per beam avoids all hypotheses aliasing one shared list object.
    top_seqs = [([BOS], 0) for _ in range(self.beam_size)]

    enc_outputs = self.model.enc(seq, pos)
    # Tile the source and every encoder output once per beam along dim 0.
    seq_beam = Variable(seq.data.repeat(self.beam_size, 1))
    enc_outputs_beam = [Variable(enc_output.data.repeat(self.beam_size, 1, 1))
                        for enc_output in enc_outputs]

    input_data = self.init_input()
    # Decoder positions for the first generated step: a single position 1
    # per beam.
    input_pos = torch.arange(1, 2).unsqueeze(0)
    input_pos = input_pos.repeat(self.beam_size, 1)
    input_pos = Variable(input_pos.long(), volatile=True)

    for step in range(1, self.args.max_word_len + 1):
        if self.cuda:
            # input_pos and input_data are reassigned at the end of every
            # iteration, so they are moved to the GPU again on each step.
            input_pos = input_pos.cuda()
            input_data = input_data.cuda()

        dec_output = self.model.dec(enc_outputs_beam, seq_beam,
                                    input_data, input_pos)
        # Keep only the features of the last decoded position (word level).
        dec_output = dec_output[:, -1, :]
        # Explicit dim=-1 equals the implicit choice PyTorch makes for a
        # 2-D input and avoids the implicit-dim deprecation warning.
        out = F.log_softmax(self.model.linear(dec_output), dim=-1)
        lp = length_penalty(step)

        top_seqs, all_done, un_dones = self.beam_search(out.data + lp, top_seqs)
        if all_done:
            break

        input_data = self.update_input(top_seqs)
        # update_state returns the refreshed source tiling as its second
        # value. It must be rebound to seq_beam, the name the decoder call
        # above reads, so it stays in step with enc_outputs_beam.
        input_pos, seq_beam, enc_outputs_beam = self.update_state(
            step + 1, seq, enc_outputs, un_dones)

    tgts = []
    for cor_idxs, score in top_seqs:
        # Drop the first and last token. These are presumably BOS/EOS.
        # NOTE(review): a hypothesis that hit max_word_len without EOS would
        # lose a real word here -- confirm beam_search always appends EOS.
        # NOTE(review): output ids are mapped through src_idx2word; confirm
        # source and target share one vocabulary.
        words = [self.src_idx2word[idx] for idx in cor_idxs[1:-1]]
        tgts.append((" ".join(words), score))
    return tgts
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号