def beam_search_generate(self, src_seq, beam_n=5):
    """Decode ``src_seq`` with beam search and return the best hypotheses.

    The source sequence is embedded and encoded, the decoder LSTM is primed
    with the last encoder output, and hypotheses are then extended one
    target token at a time. Each hypothesis carries ``err``, the accumulated
    negative log-probability of its tokens (lower is better).

    A hypothesis is finished when its last token equals
    ``self.tgt_vocab.END_TOK`` or when it grows longer than
    ``5 * len(src_seq)`` tokens.

    Args:
        src_seq: Source sequence; passed to ``self.embed_seq`` and measured
            with ``len`` to derive the output length limit.
        beam_n: Number of finished hypotheses to collect. The live beam
            width shrinks as hypotheses finish, so live plus finished never
            exceeds ``beam_n``.

    Returns:
        A list of token lists, ordered by ascending ``err``. It normally
        holds ``beam_n`` entries, but may hold fewer if the search runs out
        of live hypotheses first (e.g. vocabulary smaller than ``beam_n``).
    """
    # Function-scope import: the file's top-level import block is not
    # visible from this block, so the dependency is brought in locally.
    import heapq

    dynet.renew_cg()
    embedded = self.embed_seq(src_seq)
    input_vectors = self.encode_seq(embedded)
    w = dynet.parameter(self.decoder_w)
    b = dynet.parameter(self.decoder_b)

    # Prime the decoder with the final encoder output.
    s = self.dec_lstm.initial_state()
    s = s.add_input(input_vectors[-1])

    end_tok = self.tgt_vocab.END_TOK
    max_len = 5 * len(src_seq)

    beams = [{"state": s, "out": [], "err": 0}]
    completed_beams = []

    # `beams` must be checked too: once no live hypothesis is left there is
    # nothing to expand, and looping on the completed count alone would
    # never terminate.
    while beams and len(completed_beams) < beam_n:
        states = []   # decoder state after consuming each beam's last token
        scored = []   # (err, beam index, vocab index) for every expansion
        for beam_i, beam in enumerate(beams):
            state = beam["state"]
            if beam["out"]:
                # Feed the previously emitted token; the initial beam has
                # no token yet and uses the primed state as is.
                embed_vector = self.tgt_lookup[beam["out"][-1].i]
                state = state.add_input(embed_vector)
            states.append(state)

            out_vector = dynet.affine_transform([b, w, state.output()])
            probs = dynet.softmax(out_vector).vec_value()
            base_err = beam["err"]
            for tok_i, prob in enumerate(probs):
                # A probability that underflowed to 0 would make math.log
                # raise; treat it as infinitely bad instead.
                if prob > 0:
                    err = base_err - math.log(prob)
                else:
                    err = float("inf")
                scored.append((err, beam_i, tok_i))

        # Keep only as many expansions as there are open slots. Tuple
        # ordering breaks ties on (beam index, vocab index), which is the
        # order a stable sort over the full candidate list would produce.
        width = beam_n - len(completed_beams)
        previous = beams
        beams = []
        for err, beam_i, tok_i in heapq.nsmallest(width, scored):
            out = previous[beam_i]["out"] + [self.tgt_vocab[tok_i]]
            hypothesis = {"state": states[beam_i], "out": out, "err": err}
            if out[-1] == end_tok or len(out) > max_len:
                completed_beams.append(hypothesis)
            else:
                beams.append(hypothesis)

    completed_beams.sort(key=lambda hyp: hyp["err"])
    return [hyp["out"] for hyp in completed_beams]
评论列表
文章目录