def translate(self, xs, max_length=100):
    """Greedily decode a batch of sentences on the decoder side.

    The source sentences are encoded in a separate encoder process; the
    first decoding step receives the encoder's hidden states through
    ``self.mn_decoder``. In this process ``xs`` is only used to determine
    the batch size.

    Args:
        xs: Batch of source sentences. Only ``len(xs)`` is read here.
        max_length (int): Number of decoding steps. The first step always
            runs (it is the one that receives the encoder's hidden states),
            so at least one token per sentence is produced.

    Returns:
        list: One 1-D NumPy array of predicted token ids per sentence,
        truncated just before the first EOS (id 0), if any.
    """
    eos = 0  # EOS id: used both as the start token and as the stop marker.
    batch = len(xs)

    def embed(tokens):
        # Embed the previous tokens and split them into `batch` pieces,
        # one per sentence, for the n-step RNN links.
        eys = self.embed_y(tokens)
        return chainer.functions.split_axis(eys, batch, 0, force_tuple=True)

    def predict(outputs):
        # Project the decoder outputs to vocabulary scores and greedily
        # pick the highest-scoring token id for each sentence.
        cys = chainer.functions.concat(outputs, axis=0)
        wy = self.W(cys)
        return self.xp.argmax(wy.data, axis=1).astype('i')

    result = []
    with chainer.no_backprop_mode(), chainer.using_config('train', False):
        ys = self.xp.full(batch, eos, 'i')
        # First step: receive hidden states from the encoder process.
        h, c, outputs, _ = self.mn_decoder(embed(ys))
        ys = predict(outputs)
        result.append(ys)
        # Remaining steps: recursively decode using the previously predicted
        # token. The non-MN RNN link is accessed via `actual_rnn`
        # (presumably bypassing the cross-process receive — confirm).
        for _ in range(1, max_length):
            h, c, outputs = self.mn_decoder.actual_rnn(h, c, embed(ys))
            ys = predict(outputs)
            result.append(ys)

    # Stack per-step predictions (steps, batch), transpose to one row per
    # sentence, and move to host memory.
    result = cuda.to_cpu(self.xp.stack(result).T)

    def strip_eos(y):
        # Remove the first EOS tag and everything after it.
        inds = numpy.argwhere(y == eos)
        return y[:inds[0, 0]] if len(inds) > 0 else y

    return [strip_eos(y) for y in result]
评论列表
文章目录