seq2seq.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:TOHO_AI 作者: re53min 项目源码 文件源码
def test_decode(self, start, eos, limit):
        output = []
        y = chainer.Variable(np.array([[start]], dtype=np.int32))

        for i in range(limit):
            decode0 = self.output_embed(y)
            decode1 = self.decode1(decode0)
            decode2 = self.decode2(decode1)
            z = self.output(decode2)
            prob = F.softmax(z)

            index = np.argmax(cuda.to_cpu(prob.data))

            if index == eos:
                break
            output.append(index)
            y = chainer.Variable(np.array([index], dtype=np.int32))
        return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号