def translate(self, sentence: np.ndarray, max_length: int = 30) -> List[int]:
    """Greedily decode a translation of ``sentence``.

    The source sentence is reversed, embedded, and run through the encoder.
    Decoding then starts by feeding ``EOS`` as the first decoder input. At
    each step, the highest-scoring word ID is selected and fed back in as
    the next input. Decoding stops as soon as ``EOS`` is predicted (``EOS``
    itself is not included in the result) or after ``max_length`` steps.

    Args:
        sentence: Source-side token IDs. Presumably a 1-D integer array of
            vocabulary IDs; confirm against ``_embed_input``.
        max_length: Maximum number of decoding steps, and therefore the
            maximum length of the returned list.

    Returns:
        The predicted target word IDs, as plain Python ``int`` values.
    """
    # Inference only. Build no backprop graph, and run train-dependent
    # functions (e.g. dropout) in their evaluation behaviour.
    with chainer.no_backprop_mode(), chainer.using_config('train', False):
        # The source is fed reversed. Presumably this matches how inputs
        # were presented during training; confirm with the training code.
        sentence = sentence[::-1]
        embedded_xs = self._embed_input(sentence)
        # No initial hidden/cell state; the input is a batch of one sequence.
        # The third return value holds the per-step encoder outputs, which
        # are passed to the decoder for attention.
        hidden_states, cell_states, attentions = self._encoder(None, None, [embedded_xs])

        wid = EOS  # EOS also serves as the first decoder input token.
        result = []
        for _ in range(max_length):
            output, hidden_states, cell_states = \
                self._translate_one_word(wid, hidden_states, cell_states, attentions)
            # np.argmax yields a NumPy integer scalar. Convert it so the
            # result honours the List[int] annotation, and so every decoder
            # step receives the same kind of ID as the initial EOS.
            wid = int(np.argmax(output.data))
            if wid == EOS:
                break
            result.append(wid)
        return result
评论列表
文章目录