BiGRU.py source code


Project: NANHM-for-GEC, author: shinochin
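```python
import chainer
import chainer.functions as F
import numpy as np
from chainer import cuda

# Method of the model class in BiGRU.py. EOS (the end-of-sequence token id) and
# sequence_embed are module-level names defined elsewhere in the file (not shown).


def translate(self, xs, max_length=100):
    """Greedily decode a batch of source sequences xs (a list of int arrays)."""
    batch = len(xs)
    with chainer.no_backprop_mode(), chainer.using_config('train', False):
        # Forward and reversed copies of every source sequence.
        xs_f = xs
        xs_b = [x[::-1] for x in xs]
        exs_f = sequence_embed(self.embed_x, xs_f)
        exs_b = sequence_embed(self.embed_x, xs_b)
        # Keep only the final hidden states. Assuming encoder_f / encoder_b are
        # L.NStepGRU links, each has shape (n_layers, batch, units).
        fx, _ = self.encoder_f(None, exs_f)
        bx, _ = self.encoder_b(None, exs_b)
        # -> (n_layers, batch, 2 * units): the decoder's hidden size must be
        # twice that of each encoder.
        h = F.concat([fx, bx], axis=2)
        # EOS doubles as the start token for the first decoding step.
        ys = self.xp.full(batch, EOS, 'i')
        result = []
        for _ in range(max_length):
            eys = self.embed_y(ys)
            eys = F.split_axis(eys, batch, 0)  # one (1, dim) input per example
            h, ys = self.decoder(h, eys)
            cys = F.concat(ys, axis=0)
            wy = self.W(cys)
            # Greedy choice: the highest-scoring token for each example.
            ys = self.xp.argmax(wy.data, axis=1).astype('i')
            result.append(ys)

    # (max_length, batch) -> (batch, max_length), moved to host memory.
    result = cuda.to_cpu(self.xp.stack(result).T)

    # Remove EOS tags: cut each output at its first EOS.
    outs = []
    for y in result:
        inds = np.argwhere(y == EOS)
        if len(inds) > 0:
            y = y[:inds[0, 0]]
        outs.append(y)
    return outs
```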
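`translate()` calls `sequence_embed`, which this excerpt does not show. A helper with that name in Chainer's official seq2seq example embeds a whole batch of variable-length sequences in a single lookup. It concatenates the sequences, embeds them, and splits the result back at the cumulative lengths. The NumPy sketch below assumes BiGRU.py's helper works the same way. `sequence_embed_np`, the matrix `E` and the toy inputs are hypothetical, and they stand in for the `self.embed_x` link.

```python
import numpy as np


def sequence_embed_np(E, xs):
    """Hypothetical NumPy analogue of sequence_embed.

    E: (vocab, dim) embedding matrix; xs: list of 1-D int arrays of varying length.
    Returns one (len(x), dim) array per input sequence.
    """
    lengths = [len(x) for x in xs]
    sections = np.cumsum(lengths[:-1])    # split points between sequences
    flat = E[np.concatenate(xs)]          # a single lookup for the whole batch
    return np.split(flat, sections, axis=0)


E = np.arange(10).reshape(5, 2)           # toy 5-word vocabulary, 2-dim embeddings
xs = [np.array([1, 2, 3]), np.array([4])]
exs_f = sequence_embed_np(E, xs)
exs_b = sequence_embed_np(E, [x[::-1] for x in xs])  # reversed copies, as in translate()
print([e.shape for e in exs_f])           # [(3, 2), (1, 2)]
```

Doing one lookup and one split avoids a separate embedding call per sentence. The reversed copies feed the backward encoder exactly as `xs_b` does in `translate()`.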
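The decoding half of `translate()` is plain greedy search. Every example starts from EOS, and the argmax token is fed back at each step for exactly `max_length` steps. Each row of the stacked, transposed result is then cut at its first EOS. The NumPy sketch below isolates that control flow from Chainer. `greedy_decode`, `toy_step`, the transition `table` and `EOS = 0` are all hypothetical, and `toy_step` stands in for the `embed_y` → `decoder` → `W` pipeline.

```python
import numpy as np

EOS = 0  # hypothetical EOS id for this sketch; BiGRU.py defines its own EOS


def greedy_decode(step, state, batch, max_length=100):
    """Greedy decoding loop mirroring translate().

    step(state, ys) -> (state, logits), with logits of shape (batch, vocab).
    """
    ys = np.full(batch, EOS, dtype=np.int32)  # EOS doubles as the start token
    result = []
    for _ in range(max_length):
        state, logits = step(state, ys)
        ys = logits.argmax(axis=1).astype(np.int32)  # most likely token per example
        result.append(ys)
    result = np.stack(result).T  # (max_length, batch) -> (batch, max_length)
    outs = []
    for y in result:
        inds = np.argwhere(y == EOS)
        if len(inds) > 0:
            y = y[:inds[0, 0]]  # keep only the tokens before the first EOS
        outs.append(y)
    return outs


# Toy "model": the next token is table[previous token], i.e. 0 -> 1 -> 2 -> 3 -> 0.
table = np.array([1, 2, 3, 0])


def toy_step(state, ys):
    logits = np.eye(4)[table[ys]]  # one-hot scores that favour table[ys]
    return state, logits


outs = greedy_decode(toy_step, None, batch=2, max_length=6)
print([o.tolist() for o in outs])  # [[1, 2, 3], [1, 2, 3]]
```

Like the original, the loop always runs `max_length` steps, even after every example has emitted EOS. Tokens generated after the first EOS are computed and then discarded. A row that never produces EOS is returned in full.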