JA_Parallel_BiGRU.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:NANHM-for-GEC 作者: shinochin 项目源码 文件源码
def translate(self, xs, max_length=100):
        """Greedily decode a batch of source sentences into target id arrays.

        Each sentence is encoded twice -- once as a sequence of word ids and
        once as a sequence of character ids -- by a forward and a backward
        encoder per granularity. The per-step encoder outputs are handed to
        the decoder, which is then run one step at a time, feeding back the
        arg-max prediction of the previous step.

        Args:
            xs: Batch of source sentences. Each sentence is iterated item by
                item for the word-id lookup and joined with ``"".join`` for
                the character-id lookup, so it is expected to be a sequence
                of strings.
            max_length: Number of decoding steps to run. Every sentence is
                decoded for exactly this many steps; the output is trimmed
                afterwards.

        Returns:
            A list with one 1-D integer array per input sentence holding the
            predicted target ids, cut off just before the first ``EOS``.
            An empty batch yields ``[]``; a non-positive ``max_length``
            yields one empty ``int32`` array per sentence.
        """
        print("Now translating")
        batch = len(xs)
        print("batch", batch)

        # Nothing to encode: bail out before handing empty input to the
        # encoders/decoder.
        if batch == 0:
            return []
        # No decoding step would run, and stacking an empty step list fails,
        # so answer with empty predictions directly.
        if max_length <= 0:
            return [np.empty(0, dtype=np.int32) for _ in range(batch)]

        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            # Map words and characters to ids; unknown symbols become UNK.
            wxs = [np.array([source_word_ids.get(w, UNK) for w in x], dtype=np.int32) for x in xs]
            cxs = [np.array([source_char_ids.get(c, UNK) for c in "".join(x)], dtype=np.int32) for x in xs]

            wexs = sequence_embed(self.embed_xw, wxs)
            cexs = sequence_embed(self.embed_xc, cxs)

            # The backward encoders read each sequence in reversed order.
            wexs_b = [wex[::-1] for wex in wexs]
            cexs_b = [cex[::-1] for cex in cexs]

            _, hfw = self.encoder_fw(None, wexs)
            h1, hbw = self.encoder_bw(None, wexs_b)
            _, hfc = self.encoder_fc(None, cexs)
            h2, hbc = self.encoder_bc(None, cexs_b)

            # Flip the backward outputs back into the original order so they
            # line up step by step with the forward outputs.
            hbw = [F.get_item(h, range(len(h))[::-1]) for h in hbw]
            hbc = [F.get_item(h, range(len(h))[::-1]) for h in hbc]
            # Join forward/backward outputs along axis 1, then stack the
            # word-level and character-level outputs along axis 0.
            htw = [F.concat([fw, bw], axis=1) for fw, bw in zip(hfw, hbw)]
            htc = [F.concat([fc, bc], axis=1) for fc, bc in zip(hfc, hbc)]
            ht = [F.concat([tw, tc], axis=0) for tw, tc in zip(htw, htc)]

            # Decoding starts from EOS for every sentence in the batch.
            ys = self.xp.full(batch, EOS, 'i')
            result = []
            # Initial decoder state: the two backward encoders' first return
            # values joined on axis 2 (presumably the hidden-unit axis --
            # TODO confirm against the encoder implementation).
            h = F.concat([h1, h2], axis=2)
            for _ in range(max_length):
                eys = self.embed_y(ys)
                eys = F.split_axis(eys, batch, 0)
                h_list, h_bar_list, c_s_list, z_s_list = self.decoder(h, ht, eys)
                cys = F.concat(h_list, axis=0)
                wy = self.W(cys)
                # Greedy choice: highest-scoring target id per sentence.
                ys = self.xp.argmax(wy.data, axis=1).astype('i')
                result.append(ys)
                # Carry the last decoder output over as the next state.
                h = F.transpose_sequence(h_list)[-1]
                h = F.reshape(h, (self.n_layers, h.shape[0], h.shape[1]))

        # Steps were collected first; transpose so each row is one sentence.
        result = cuda.to_cpu(self.xp.stack(result).T)

        # Remove EOS tags: keep only the ids before the first EOS.
        outs = []
        for y in result:
            inds = np.argwhere(y == EOS)
            if len(inds) > 0:
                y = y[:inds[0, 0]]
            outs.append(y)
        return outs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号