def translate(self, xs, max_length=100):
    """Greedily decode a batch of source sequences into target token ids.

    The source batch is encoded twice: once as given by ``self.encoder_f``
    and once with every sequence reversed by ``self.encoder_b``. The two
    first return values are concatenated along axis 2, and that result is
    the decoder's initial state. Decoding starts from an ``EOS`` token.
    At every step, the arg-max of ``self.W``'s output is fed back as the
    next input (greedy search).

    Everything runs under ``chainer.no_backprop_mode()`` with
    ``chainer.using_config('train', False)``.

    Args:
        xs: Batch of source sequences. Each element is reversed with
            ``x[::-1]`` and passed to ``sequence_embed``. Presumably these
            are 1-D integer arrays of token ids (TODO confirm with callers).
        max_length: Maximum number of decoding steps. A value ``<= 0``
            yields one empty sequence per input.

    Returns:
        A list with one NumPy int32 array per element of ``xs``. Each array
        holds the decoded token ids up to, but excluding, the first ``EOS``.
        A hypothesis that never emits ``EOS`` has ``max_length`` tokens.
    """
    batch = len(xs)
    if batch == 0:
        # Nothing to decode; don't feed an empty batch to the encoders.
        return []
    if max_length <= 0:
        # No decoding steps would run, and ``xp.stack`` below raises on an
        # empty list, so return one empty int32 sequence per input instead.
        return [np.empty(0, dtype='i') for _ in range(batch)]

    with chainer.no_backprop_mode(), chainer.using_config('train', False):
        xs_b = [x[::-1] for x in xs]
        exs_f = sequence_embed(self.embed_x, xs)
        exs_b = sequence_embed(self.embed_x, xs_b)
        fx, _ = self.encoder_f(None, exs_f)
        bx, _ = self.encoder_b(None, exs_b)
        # Join the forward and backward encoder states along axis 2
        # (presumably the hidden-unit axis -- confirm) to seed the decoder.
        h = F.concat([fx, bx], axis=2)

        # Every hypothesis starts from EOS, used here as the start symbol.
        ys = self.xp.full(batch, EOS, 'i')
        finished = self.xp.zeros(batch, dtype=bool)
        result = []
        for _ in range(max_length):
            eys = self.embed_y(ys)
            # Split into `batch` pieces along axis 0, one per hypothesis.
            # The decoder consumes and returns per-sequence lists.
            eys = F.split_axis(eys, batch, 0)
            h, dys = self.decoder(h, eys)
            wy = self.W(F.concat(dys, axis=0))
            ys = self.xp.argmax(wy.data, axis=1).astype('i')
            result.append(ys)
            # Tokens after a hypothesis' first EOS are discarded below.
            # Once every hypothesis has produced an EOS, further steps
            # cannot change the returned sequences, so stop early.
            finished |= ys == EOS
            if finished.all():
                break

    # Stack the steps and transpose to (batch, steps); move to host memory.
    result = cuda.to_cpu(self.xp.stack(result).T)

    # Remove EOS tags: keep only the tokens before the first EOS.
    outs = []
    for y in result:
        inds = np.argwhere(y == EOS)
        if len(inds) > 0:
            y = y[:inds[0, 0]]
        outs.append(y)
    return outs
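# 评论列表 ("Comment list") -- scraped web-page navigation label, not code.
# 文章目录 ("Table of contents") -- scraped web-page navigation label, not code.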