def translate_with_beam_search(self, sentence: np.ndarray, max_length: int = 30,
                               beam_width: int = 3) -> List[int]:
    with chainer.no_backprop_mode(), chainer.using_config('train', False):
        # Reverse the source sentence before encoding (the usual seq2seq trick),
        # then run the encoder once; its states seed every beam hypothesis.
        sentence = sentence[::-1]
        embedded_xs = self._embed_input(sentence)
        hidden_states, cell_states, attentions = self._encoder(None, None, [embedded_xs])

        # heaps[i] holds the partial translations of length i.
        heaps = [[] for _ in range(max_length + 1)]
        heaps[0].append((0, [EOS], hidden_states, cell_states))  # (score, translation, hidden_states, cell_states)

        solution = []
        solution_score = 1e8  # cost (negative log-likelihood) of the best finished translation so far
        for i in range(max_length):
            # Prune to the beam_width lowest-cost hypotheses of length i.
            heaps[i] = sorted(heaps[i], key=lambda t: t[0])[:beam_width]
            for score, translation, i_hidden_states, i_cell_states in heaps[i]:
                wid = translation[-1]
                output, new_hidden_states, new_cell_states = \
                    self._translate_one_word(wid, i_hidden_states, i_cell_states, attentions)

                # Expand candidates from most to least probable next word.
                for next_wid in np.argsort(output.data)[::-1]:
                    if output.data[next_wid] < 1e-6:
                        break  # remaining words are negligibly likely
                    next_score = score - np.log(output.data[next_wid])
                    if next_score > solution_score:
                        break  # cannot improve on the best finished translation
                    next_translation = translation + [next_wid]
                    next_item = (next_score, next_translation, new_hidden_states, new_cell_states)
                    if next_wid == EOS:
                        if next_score < solution_score:
                            solution = translation[1:]  # [1:] drops the leading EOS
                            solution_score = next_score
                    else:
                        heaps[i + 1].append(next_item)
        return solution
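
To see the beam-search bookkeeping in isolation, here is a minimal, self-contained sketch that mirrors the loop above but replaces the Chainer model with a made-up toy_step transition table. The 4-word vocabulary, the TRANSITIONS table, and the toy_step function are all invented for illustration; they stand in for _translate_one_word and are not part of the model above.

import numpy as np

EOS = 0  # assumed end-of-sequence id, as in the snippet above

# Hypothetical next-word distributions for a 4-word vocabulary:
# row = previous word id, column = next word id.
# 0 = EOS, 1 = "i", 2 = "like", 3 = "cats"
TRANSITIONS = np.array([
    [0.05, 0.80, 0.10, 0.05],  # after EOS/BOS, "i" is most likely
    [0.05, 0.05, 0.80, 0.10],  # after "i", "like" is most likely
    [0.10, 0.05, 0.05, 0.80],  # after "like", "cats" is most likely
    [0.85, 0.05, 0.05, 0.05],  # after "cats", the sentence tends to end
])


def toy_step(prev_wid):
    """Stand-in for _translate_one_word: returns P(next word | previous word)."""
    return TRANSITIONS[prev_wid]


def beam_search(max_length=10, beam_width=3):
    heaps = [[] for _ in range(max_length + 1)]   # heaps[i]: partial translations of length i
    heaps[0].append((0.0, [EOS]))                 # (cost, translation), starting from EOS
    solution, solution_score = [], 1e8
    for i in range(max_length):
        # keep only the beam_width lowest-cost hypotheses of this length
        heaps[i] = sorted(heaps[i], key=lambda t: t[0])[:beam_width]
        for score, translation in heaps[i]:
            probs = toy_step(translation[-1])
            for next_wid in np.argsort(probs)[::-1]:   # most probable first
                if probs[next_wid] < 1e-6:
                    break                              # remaining words are negligible
                next_score = score - np.log(probs[next_wid])
                if next_score > solution_score:
                    break                              # cannot beat the best finished translation
                if next_wid == EOS:
                    if next_score < solution_score:
                        solution = translation[1:]     # drop the leading EOS
                        solution_score = next_score
                else:
                    heaps[i + 1].append((next_score, translation + [int(next_wid)]))
    return solution


print(beam_search())  # [1, 2, 3], i.e. "i like cats"

Because candidates are expanded in decreasing probability order, the two break statements can safely skip the rest of the vocabulary as soon as a word is negligibly likely or the accumulated cost already exceeds the best finished translation.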