def generate(self, src, sampled=False):
    """Decode an output sequence for ``src`` one token at a time.

    The source is embedded and encoded, the last encoder output is fed to
    the decoder LSTM as its first input, and tokens are then emitted step
    by step until the end token is produced or the length cap is reached.

    Args:
        src: Source sequence. It is passed to ``self.embed_seq`` and its
            length bounds the number of decoding steps.
        sampled: If false (default), pick the highest-probability token at
            every step (greedy decoding). If true, draw each token at
            random from the softmax distribution instead.

    Returns:
        A list of ``self.tgt_vocab`` entries in generation order. When the
        end token is produced it is included as the last element.
    """
    # Start from a fresh computation graph for this decoding run.
    dynet.renew_cg()
    embedding = self.embed_seq(src)
    # Only the final encoder output is used to condition the decoder.
    encoding = self.encode_seq(embedding)[-1]
    w = dynet.parameter(self.decoder_w)
    b = dynet.parameter(self.decoder_b)
    s = self.dec_lstm.initial_state().add_input(encoding)
    out = []
    # Cap the output at five tokens per source token so decoding always
    # terminates even if the end token is never predicted.
    for _ in range(5 * len(src)):
        out_vector = dynet.affine_transform([b, w, s.output()])
        probs = dynet.softmax(out_vector).value()
        if sampled:
            # Renormalise in float64: np.random.choice rejects ``p`` whose
            # sum is not 1 within tolerance, and rounding in the softmax
            # output can push the sum slightly off.
            dist = np.asarray(probs, dtype=np.float64).ravel()
            dist /= dist.sum()
            selection = np.random.choice(dist.size, p=dist)
        else:
            selection = np.argmax(probs)
        out.append(self.tgt_vocab[selection])
        # The end token is appended before stopping, so it is part of the
        # returned sequence.
        if out[-1].s == self.tgt_vocab.END_TOK:
            break
        # Feed the chosen token's embedding back in as the next input.
        embed_vector = self.tgt_lookup[selection]
        s = s.add_input(embed_vector)
    return out
评论列表
文章目录