sequence2sequence.py source code

python

Project: lang-reps  Author: chaitanyamalaviya
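import dynet


def decode(self, encoding, input, output):
    """
    Decode a single training example with teacher forcing.

    :param encoding: last hidden state from the encoder
    :param input: source sentence (unused here; `encoding` already summarizes it)
    :param output: target sentence
    :return: summed cross-entropy loss over the target tokens (a DyNet expression)
    """
    # Map target tokens to vocabulary entries; `.i` below is each entry's integer index.
    tgt_toks = [self.tgt_vocab[tok] for tok in output]

    # Load the output-projection parameters into the computation graph.
    w = dynet.parameter(self.decoder_w)
    b = dynet.parameter(self.decoder_b)

    # The encoder summary is the first input fed to the decoder LSTM.
    s = self.dec_lstm.initial_state().add_input(encoding)
    loss = []

    for tok in tgt_toks:
        # Score every target word from the current decoder state: b + w * h.
        out_vector = dynet.affine_transform([b, w, s.output()])
        probs = dynet.softmax(out_vector)
        # Cross-entropy for the gold token: -log p(tok | previous tokens).
        cross_ent_loss = -dynet.log(dynet.pick(probs, tok.i))
        loss.append(cross_ent_loss)
        # Teacher forcing: feed the gold token's embedding, not the model's prediction.
        embed_vector = self.tgt_lookup[tok.i]
        s = s.add_input(embed_vector)

    return dynet.esum(loss)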
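The loop in `decode` is teacher-forced sequence training: at each step the decoder scores the whole target vocabulary, the loss takes `-log` of the gold token's probability, and the gold token (not the prediction) is fed back in. Below is a minimal pure-Python sketch of the same computation so it can run without DyNet. It assumes a hypothetical tanh RNN cell (`rnn_step`) in place of the LSTM, and the helper names `affine`, `neg_log_softmax` and `teacher_forced_loss` are illustrative, not part of lang-reps or DyNet. It computes `-log softmax(x)[i]` as `logsumexp(x) - x[i]`, which avoids overflow; DyNet's `dynet.pickneglogsoftmax(x, i)` offers the same quantity as a single call.

```python
import math
from typing import List, Sequence

Vector = List[float]


def affine(b: Sequence[float], W: Sequence[Sequence[float]], x: Sequence[float]) -> Vector:
    """b + W x: the quantity dynet.affine_transform([b, W, x]) builds."""
    return [bi + sum(wij * xj for wij, xj in zip(row, x)) for bi, row in zip(b, W)]


def neg_log_softmax(scores: Sequence[float], gold: int) -> float:
    """-log(softmax(scores)[gold]), computed stably as logsumexp(scores) - scores[gold]."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[gold]


def rnn_step(h: Sequence[float], x: Sequence[float], U, V) -> Vector:
    """Hypothetical tanh RNN cell standing in for the LSTM: h' = tanh(U h + V x)."""
    return [
        math.tanh(sum(u * hj for u, hj in zip(u_row, h)) + sum(v * xj for v, xj in zip(v_row, x)))
        for u_row, v_row in zip(U, V)
    ]


def teacher_forced_loss(encoding, tgt_ids, embeddings, W, b, U, V) -> float:
    """Sum of per-token cross-entropies, mirroring the loop in decode()."""
    # Start from a zero state and feed the encoder summary as the first input,
    # like dec_lstm.initial_state().add_input(encoding).
    h = rnn_step([0.0] * len(U), encoding, U, V)
    total = 0.0
    for tok in tgt_ids:
        scores = affine(b, W, h)                # one score per target-vocabulary entry
        total += neg_log_softmax(scores, tok)   # cross-entropy for the gold token
        h = rnn_step(h, embeddings[tok], U, V)  # teacher forcing: feed the gold token
    return total


# Toy usage: hidden size 2, embedding size 2, target vocabulary of 3 words.
U = [[0.5, 0.0], [0.0, 0.5]]
V = [[1.0, 0.0], [0.0, 1.0]]
emb = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
W = [[0.0, 0.0] for _ in range(3)]  # zero projection, so every softmax is uniform
b = [0.0, 0.0, 0.0]
loss = teacher_forced_loss([0.2, -0.4], [1, 2], emb, W, b, U, V)
# With W and b all zero, each of the 2 target tokens costs log(3).
```

Subtracting the maximum score before exponentiating is what keeps `neg_log_softmax([1000.0, 0.0], 0)` finite, whereas a naive `math.exp(1000.0)` raises `OverflowError`.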