seq2seq_mp1.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:chainermn 作者: chainer 项目源码 文件源码
def translate(self, xs, max_length=100):
        """Greedily decode a batch of target sequences.

        Decoding starts from token id 0 for every sequence. At each step
        the token with the highest score under ``self.W`` is picked and fed
        back as the next input, for exactly ``max_length`` steps.

        Args:
            xs: Batch of source inputs. Only ``len(xs)`` is used here to
                determine the batch size; the encoder hidden states arrive
                through ``self.mn_decoder`` (presumably from a separate
                encoder process -- TODO confirm against the encoder side).
            max_length: Number of decoding steps to run.

        Returns:
            A list with one 1-D ``numpy`` int32 array per batch element,
            truncated before the first occurrence of token id 0 (treated
            as EOS).
        """
        batch = len(xs)

        def embed(tokens):
            # Embed the token batch, then split it into a tuple of `batch`
            # per-sequence variables, the form handed to the decoder links.
            return chainer.functions.split_axis(
                self.embed_y(tokens), batch, 0, force_tuple=True)

        def predict(hs):
            # Stack the per-sequence decoder outputs, project them to token
            # scores, and take the argmax token id for each sequence.
            scores = self.W(chainer.functions.concat(hs, axis=0))
            return self.xp.argmax(scores.data, axis=1).astype('i')

        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            # Every sequence starts from token id 0.
            ys = self.xp.zeros(batch, 'i')

            # First step: the multi-node decoder receives the hidden states
            # from the encoder process.
            h, c, hs, _ = self.mn_decoder(embed(ys))
            ys = predict(hs)
            result = [ys]

            # Remaining steps feed back the previously predicted token.
            # The non-MN RNN link is reachable via `actual_rnn`, so no
            # further communication with the encoder process is needed.
            for _ in range(1, max_length):
                h, c, hs = self.mn_decoder.actual_rnn(h, c, embed(ys))
                ys = predict(hs)
                result.append(ys)

        # (max_length, batch) -> (batch, max_length), moved to host memory.
        result = cuda.to_cpu(self.xp.stack(result).T)

        # Strip everything from the first EOS tag (token id 0) onward.
        outs = []
        for y in result:
            inds = numpy.argwhere(y == 0)
            if len(inds) > 0:
                y = y[:inds[0, 0]]
            outs.append(y)
        return outs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号