def translate(self, xs, max_length=100):
    """Greedily decode a batch of source sequences.

    Each source sequence is embedded with ``self.embed_x`` and encoded twice:
    as-is by ``self.encoder_f`` and reversed by ``self.encoder_b``. The two
    per-sequence encoder outputs are joined along axis 1 and passed to
    ``self.decoder``. Decoding starts from an ``EOS`` token for every
    sequence and runs for exactly ``max_length`` steps (there is no early
    exit). Each step feeds back the argmax of ``self.W`` over the decoder
    output.

    Args:
        xs: list of source token-id sequences, one per batch element
            (presumably 1-D int arrays on ``self.xp`` -- TODO confirm).
        max_length: number of decoding steps to run.

    Returns:
        list of 1-D host (numpy) arrays, one per input sequence. Each is
        truncated just before the first ``EOS`` it produced, or left intact
        if it produced none.
    """
    n_seq = len(xs)
    with chainer.no_backprop_mode(), chainer.using_config('train', False):
        emb_fwd = sequence_embed(self.embed_x, xs)
        emb_bwd = sequence_embed(self.embed_x, [seq[::-1] for seq in xs])
        _, fwd_out = self.encoder_f(None, emb_fwd)
        _, bwd_out = self.encoder_b(None, emb_bwd)
        # NOTE(review): bwd_out[i] is computed on the reversed sequence and is
        # concatenated without being reversed back. If these are per-timestep
        # outputs, step t of the forward pass is paired with step len-1-t of
        # the backward pass. Confirm this matches the training code.
        encoded = [F.concat([f, b], axis=1) for f, b in zip(fwd_out, bwd_out)]

        prev_tokens = self.xp.full(n_seq, EOS, 'i')
        steps = []
        for _ in range(max_length):
            token_emb = chainer.functions.split_axis(
                self.embed_y(prev_tokens), n_seq, 0)
            # NOTE(review): the decoder's initial state is None on every step,
            # so no recurrent state is carried between steps here; only the
            # previously emitted token links them. Confirm this is intended.
            dec_out, _, _, _ = self.decoder(None, encoded, token_emb)
            logits = self.W(chainer.functions.concat(dec_out, axis=0))
            prev_tokens = self.xp.argmax(logits.data, axis=1).astype('i')
            steps.append(prev_tokens)

    # Stack the per-step token vectors, then transpose so that each row
    # holds one sequence's tokens; move the result to host memory.
    tokens = cuda.to_cpu(self.xp.stack(steps).T)

    # Cut each sequence at its first EOS tag.
    outs = []
    for row in tokens:
        eos_at = np.argwhere(row == EOS)
        outs.append(row[:eos_at[0, 0]] if len(eos_at) > 0 else row)
    return outs
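# 评论列表 (comment list) -- stray web-page residue, not part of the code
# 文章目录 (table of contents) -- stray web-page residue, not part of the code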