sequence2sequence.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:lang-reps 作者: chaitanyamalaviya 项目源码 文件源码
def generate(self, src, sampled=False):
        dynet.renew_cg()

        embedding = self.embed_seq(src)
        encoding = self.encode_seq(embedding)[-1]

        w = dynet.parameter(self.decoder_w)
        b = dynet.parameter(self.decoder_b)

        s = self.dec_lstm.initial_state().add_input(encoding)

        out = []
        for _ in range(5*len(src)):
            out_vector = dynet.affine_transform([b, w, s.output()])
            probs = dynet.softmax(out_vector)
            selection = np.argmax(probs.value())
            out.append(self.tgt_vocab[selection])
            if out[-1].s == self.tgt_vocab.END_TOK: break
            embed_vector = self.tgt_lookup[selection]
            s = s.add_input(embed_vector)
        return out
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号