Att_BiGRU.py source file

python

Project: NANHM-for-GEC  Author: shinochin
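# Imports this method relies on. EOS and sequence_embed are defined
# elsewhere in the original module and are not shown in this excerpt.
import numpy as np
import chainer
import chainer.functions as F
from chainer import cuda


def translate(self, xs, max_length=100):
    """Greedily decode a batch of source sequences, up to max_length steps."""
    batch = len(xs)
    with chainer.no_backprop_mode(), chainer.using_config('train', False):
        # Encode each source sequence forwards and backwards.
        xs_f = xs
        xs_b = [x[::-1] for x in xs]
        exs_f = sequence_embed(self.embed_x, xs_f)
        exs_b = sequence_embed(self.embed_x, xs_b)
        _, hf = self.encoder_f(None, exs_f)
        _, hb = self.encoder_b(None, exs_b)
        # Concatenate forward and backward encoder outputs per sequence.
        ht = [F.concat([x, y], axis=1) for x, y in zip(hf, hb)]
        # Decoding starts with EOS as the first input token.
        ys = self.xp.full(batch, EOS, 'i')
        result = []
        for i in range(max_length):
            eys = self.embed_y(ys)
            # One embedded token per batch element.
            eys = F.split_axis(eys, batch, 0)
            h_list, h_bar_list, c_s_list, z_s_list = self.decoder(None, ht, eys)
            cys = F.concat(h_list, axis=0)
            wy = self.W(cys)
            # Greedy choice: the highest-scoring token becomes the next input.
            ys = self.xp.argmax(wy.data, axis=1).astype('i')
            result.append(ys)

    # Shape (max_length, batch) -> (batch, max_length), moved to the CPU.
    result = cuda.to_cpu(self.xp.stack(result).T)

    # Remove EOS tags: cut each output at its first EOS.
    outs = []
    for y in result:
        inds = np.argwhere(y == EOS)
        if len(inds) > 0:
            y = y[:inds[0, 0]]
        outs.append(y)
    return outs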
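
The greedy loop and the EOS trimming at the end of `translate` can be tried without Chainer. The sketch below is a NumPy-only illustration, not the project's decoder: `toy_step` is a hypothetical stand-in for the `self.decoder` and `self.W` calls, and `EOS = 0` is an assumed token id, since the real value is defined elsewhere in the original module.

```python
import numpy as np

EOS = 0  # assumed token id; the original module defines its own EOS


def toy_step(ys, vocab_size=5):
    """Hypothetical scorer: always prefers token (previous + 1) % vocab_size."""
    scores = np.zeros((len(ys), vocab_size), dtype='f')
    scores[np.arange(len(ys)), (ys + 1) % vocab_size] = 1.0
    return scores


def greedy_decode(batch, max_length=8):
    """Feed the argmax of each step back in as the next input token."""
    ys = np.full(batch, EOS, 'i')  # decoding starts from EOS, as in translate
    result = []
    for _ in range(max_length):
        wy = toy_step(ys)
        ys = np.argmax(wy, axis=1).astype('i')
        result.append(ys)
    # (max_length, batch) -> (batch, max_length)
    return np.stack(result).T


def trim_at_eos(rows, eos=EOS):
    """Cut each row at its first EOS; rows without EOS are kept whole."""
    outs = []
    for y in rows:
        inds = np.argwhere(y == eos)
        if len(inds) > 0:
            y = y[:inds[0, 0]]
        outs.append(y)
    return outs


decoded = greedy_decode(batch=2)
trimmed = [t.tolist() for t in trim_at_eos(decoded)]
print(trimmed)  # [[1, 2, 3, 4], [1, 2, 3, 4]]
```

Note that the loop always runs for `max_length` steps even after every sequence has produced EOS. The trimming pass discards the surplus tokens afterwards, which matches the structure of `translate` above.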