sequence2sequence.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:lang-reps 作者: chaitanyamalaviya 项目源码 文件源码
def beam_search_generate(self, src_seq, beam_n=5):
        """Decode ``src_seq`` with beam search and return the finished hypotheses.

        The source is embedded and encoded, the decoder LSTM is seeded with the
        last encoder vector, and hypotheses are then expanded one target token
        at a time.  Each hypothesis carries an "err" score: the accumulated
        negative log-probability of its tokens (lower is better).

        A hypothesis is finished when its last token equals
        ``self.tgt_vocab.END_TOK`` or when it grows longer than
        ``5 * len(src_seq)`` tokens.

        Args:
            src_seq: Source sequence, passed to ``self.embed_seq``.
            beam_n: Maximum number of finished hypotheses to collect; also
                bounds how many hypotheses are kept alive at each step.

        Returns:
            A list of output token lists, ordered from lowest to highest
            accumulated error.  It may hold fewer than ``beam_n`` entries if
            every live hypothesis finishes before ``beam_n`` are collected.
        """
        import heapq

        dynet.renew_cg()

        embedded = self.embed_seq(src_seq)
        input_vectors = self.encode_seq(embedded)

        w = dynet.parameter(self.decoder_w)
        b = dynet.parameter(self.decoder_b)

        # Loop invariants used by the termination test below.
        end_tok = self.tgt_vocab.END_TOK
        max_len = 5 * len(src_seq)

        s = self.dec_lstm.initial_state()
        s = s.add_input(input_vectors[-1])
        beams = [{"state": s, "out": [], "err": 0}]
        completed_beams = []

        # Stop as well when no live beam is left: otherwise the loop would spin
        # forever whenever fewer than beam_n hypotheses can ever be produced
        # (e.g. a target vocabulary smaller than beam_n).
        while beams and len(completed_beams) < beam_n:
            # Lightweight candidates: (err, decoder state, output prefix, vocab index).
            # The extended output list is only built for candidates that survive.
            candidates = []
            for beam in beams:
                state = beam["state"]
                if beam["out"]:
                    # Feed the previously emitted token back into the decoder.
                    # NOTE(review): presumes vocab entries expose an index `.i`
                    # usable with self.tgt_lookup -- confirm against the vocab class.
                    state = state.add_input(self.tgt_lookup[beam["out"][-1].i])

                out_vector = dynet.affine_transform([b, w, state.output()])
                probs = dynet.softmax(out_vector).vec_value()

                for next_i, prob in enumerate(probs):
                    if prob <= 0.0:
                        # math.log would raise on an underflowed probability;
                        # such a candidate could never be selected anyway.
                        continue
                    candidates.append(
                        (beam["err"] - math.log(prob), state, beam["out"], next_i)
                    )

            # nsmallest(k, ...) is equivalent to sorted(...)[:k] (ties keep
            # their insertion order) without sorting every candidate.
            width = beam_n - len(completed_beams)
            best = heapq.nsmallest(width, candidates, key=lambda cand: cand[0])

            beams = []
            for err, state, prefix, next_i in best:
                out = prefix + [self.tgt_vocab[next_i]]
                hypothesis = {"state": state, "out": out, "err": err}
                if out[-1] == end_tok or len(out) > max_len:
                    completed_beams.append(hypothesis)
                else:
                    beams.append(hypothesis)

        completed_beams.sort(key=lambda x: x["err"])
        return [beam["out"] for beam in completed_beams]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号