sequence2sequence.py source code

python

Project: lang-reps  Author: chaitanyamalaviya
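import dynet


def decode(self, input_vectors, output):
        """Return the summed negative log-likelihood of `output` (teacher forcing)."""

        # Look up each target token in the vocabulary; `.i` is presumably its integer id.
        tgt_toks = [self.tgt_vocab[tok] for tok in output]

        # Parameters of the output projection layer.
        w = dynet.parameter(self.decoder_w)
        b = dynet.parameter(self.decoder_b)

        # Prime the decoder LSTM with the last encoder vector, padded with a
        # zero-initialized vector of size hidden_dim.
        s = self.dec_lstm.initial_state()
        s = s.add_input(dynet.concatenate([
                                            input_vectors[-1],
                                            dynet.vecInput(self.args.hidden_dim)
                                          ]))
        loss = []
        for tok in tgt_toks:
            # Score the vocabulary: b + w * (current decoder output).
            out_vector = dynet.affine_transform([b, w, s.output()])
            probs = dynet.softmax(out_vector)
            # Negative log-probability of the gold token at this step.
            loss.append(-dynet.log(dynet.pick(probs, tok.i)))
            # Feed the gold token embedding and the attention context to the next step.
            embed_vector = self.tgt_lookup[tok.i]
            attn_vector = self.attend(input_vectors, s)
            inp = dynet.concatenate([embed_vector, attn_vector])
            s = s.add_input(inp)

        # Sum the per-token losses into one scalar expression.
        loss = dynet.esum(loss)
        return loss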
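
The quantity that `decode` accumulates is a sum of per-step negative log-probabilities of the gold tokens. The following is a minimal, framework-free sketch of that arithmetic only, not the project's implementation: the helper names `softmax` and `sequence_nll` and the toy data are hypothetical, and the LSTM and attention steps are replaced by precomputed score vectors.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sequence_nll(step_scores, target_ids):
    """Sum of -log p(gold token) over all decoding steps."""
    if len(step_scores) != len(target_ids):
        raise ValueError("need exactly one score vector per target token")
    loss = 0.0
    for scores, tgt in zip(step_scores, target_ids):
        probs = softmax(scores)
        loss += -math.log(probs[tgt])
    return loss


# Hypothetical toy input: two decoding steps over a 3-word vocabulary,
# with uniform scores so every token has probability 1/3.
toy_scores = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
toy_targets = [1, 2]
toy_loss = sequence_nll(toy_scores, toy_targets)
```

With uniform scores each step contributes log(3), so `toy_loss` equals 2 * log(3). In the DyNet version the same sum is built as a computation graph expression so that gradients can flow back through the decoder.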