def translate(self, xs, max_length=100):
    """Greedily decode a batch of tokenized source sentences.

    Each sentence is encoded twice, once as word ids and once as character
    ids (the characters of the joined words). Each view goes through a
    forward and a backward encoder. Per-position forward/backward outputs
    are concatenated on the feature axis. The word-level and
    character-level results are then concatenated on the time axis, giving
    the per-sentence memory ``ht`` passed to ``self.decoder``.

    Decoding starts from ``EOS``. At every step it feeds back the argmax
    token, for at most ``max_length`` steps. It stops early once every
    sentence has produced ``EOS``; tokens after the first ``EOS`` are
    discarded anyway.

    Args:
        xs: Batch of source sentences. Each sentence is a sequence of word
            strings. Words are looked up in ``source_word_ids``; the
            characters of ``"".join(sentence)`` are looked up in
            ``source_char_ids``. Unknown entries map to ``UNK``.
        max_length: Maximum number of decoding steps.

    Returns:
        A list with one int32 NumPy array per input sentence, holding the
        predicted target ids truncated before the first ``EOS``. An empty
        batch yields ``[]``. ``max_length <= 0`` yields one empty array per
        sentence.
    """
    print("Now translating")
    batch = len(xs)
    print("batch", batch)
    if batch == 0:
        return []
    if max_length <= 0:
        # Nothing can be decoded; avoid xp.stack([]) raising below.
        return [np.empty(0, dtype=np.int32) for _ in xs]

    with chainer.no_backprop_mode(), chainer.using_config('train', False):
        wxs = [np.array([source_word_ids.get(w, UNK) for w in x], dtype=np.int32)
               for x in xs]
        cxs = [np.array([source_char_ids.get(c, UNK) for c in "".join(x)],
                        dtype=np.int32)
               for x in xs]
        wexs = sequence_embed(self.embed_xw, wxs)
        cexs = sequence_embed(self.embed_xc, cxs)

        # Backward encoders read each sequence reversed.
        # NOTE(review): the encoders are called as enc(None, seqs) and return
        # (state, per-position outputs), which looks like Chainer's NStepGRU
        # convention -- confirm against the model definition.
        _, hfw = self.encoder_fw(None, wexs)
        h1, hbw = self.encoder_bw(None, [wex[::-1] for wex in wexs])
        _, hfc = self.encoder_fc(None, cexs)
        h2, hbc = self.encoder_bc(None, [cex[::-1] for cex in cexs])

        # Flip the backward outputs back so position t aligns with the
        # forward output at position t.
        hbw = [F.get_item(h, range(len(h))[::-1]) for h in hbw]
        hbc = [F.get_item(h, range(len(h))[::-1]) for h in hbc]
        htw = [F.concat([f, b], axis=1) for f, b in zip(hfw, hbw)]
        htc = [F.concat([f, b], axis=1) for f, b in zip(hfc, hbc)]
        # Per sentence: word positions followed by character positions.
        ht = [F.concat([w, c], axis=0) for w, c in zip(htw, htc)]

        # Initial decoder state: the backward word and char encoder states
        # joined on axis 2.
        h = F.concat([h1, h2], axis=2)
        ys = self.xp.full(batch, EOS, 'i')
        finished = self.xp.zeros(batch, dtype=bool)
        result = []
        for _ in range(max_length):
            eys = chainer.functions.split_axis(self.embed_y(ys), batch, 0)
            h_list, _, _, _ = self.decoder(h, ht, eys)
            wy = self.W(chainer.functions.concat(h_list, axis=0))
            ys = self.xp.argmax(wy.data, axis=1).astype('i')
            result.append(ys)
            # Everything after a sentence's first EOS is truncated below, so
            # once all sentences have emitted EOS further steps are wasted.
            finished |= ys == EOS
            if finished.all():
                break
            h = F.transpose_sequence(h_list)[-1]
            # NOTE(review): if h is 2-D here, this reshape only succeeds when
            # self.n_layers == 1 -- confirm for multi-layer decoders.
            h = F.reshape(h, (self.n_layers, h.shape[0], h.shape[1]))

    # Rows are sentences, columns are decoding steps.
    result = cuda.to_cpu(self.xp.stack(result).T)

    # Remove EOS tags: keep each hypothesis up to (not including) its first EOS.
    outs = []
    for y in result:
        inds = np.argwhere(y == EOS)
        if len(inds) > 0:
            y = y[:inds[0, 0]]
        outs.append(y)
    return outs
# 评论列表 (web-page residue: "comment list")
# 文章目录 (web-page residue: "table of contents")