seq2seq.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:TOHO_AI 作者: re53min 项目源码 文件源码
def decode(self, sentences):
        # sentences = Variable(np.array([sentences], dtype=np.int32), volatile=False)
        loss = Variable(np.zeros((), dtype=np.float32))
        n_words = len(sentences)-1

        for word, t in zip(sentences, sentences[1:]):
            # print('??:{}, ??:{}'.format(word,t))
            word = Variable(np.array([[word]], dtype=np.int32))
            t = Variable(np.array([t], dtype=np.int32))
            decode0 = self.output_embed(word)
            decode1 = self.decode1(decode0)
            decode2 = self.decode2(decode1)
            z = self.output(decode2)

            loss += F.softmax_cross_entropy(z, t)

        return loss, n_words
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号