encoder_decoder.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:workspace 作者: nojima 项目源码 文件源码
def translate_with_beam_search(self, sentence: np.ndarray, max_length: int = 30, beam_width: int = 3) -> List[int]:
        """Translate ``sentence`` using a pruned beam search over the decoder.

        Hypotheses are scored by their cumulative negative log-probability
        (lower is better). At each decoding step only the ``beam_width`` best
        unfinished hypotheses are expanded. A hypothesis is finished when the
        decoder emits ``EOS``. The best finished hypothesis found within
        ``max_length`` steps is returned.

        Args:
            sentence: Source token ids; reversed before being encoded.
            max_length: Maximum number of decoding steps.
            beam_width: Number of unfinished hypotheses kept per step.

        Returns:
            Target token ids as plain Python ``int`` values, without the
            leading start ``EOS`` and the terminating ``EOS``. Returns an
            empty list when no hypothesis emits ``EOS`` within
            ``max_length`` steps.
        """
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            # The encoder consumes the source sentence in reverse order.
            sentence = sentence[::-1]

            embedded_xs = self._embed_input(sentence)
            hidden_states, cell_states, attentions = self._encoder(None, None, [embedded_xs])

            # Each hypothesis is (score, translation, hidden_states, cell_states).
            # EOS doubles as the start token fed to the decoder on step one.
            beam = [(0, [EOS], hidden_states, cell_states)]

            solution = []
            solution_score = float('inf')

            for _ in range(max_length):
                if not beam:
                    # Every hypothesis has either finished or been pruned.
                    break
                beam = sorted(beam, key=lambda t: t[0])[:beam_width]
                candidates = []

                for score, translation, i_hidden_states, i_cell_states in beam:
                    # Extending a hypothesis adds -log(p) >= 0 (probabilities
                    # are at most 1), so a hypothesis already worse than the
                    # best finished one cannot win. The beam is sorted, so
                    # every later hypothesis is at least as bad; stop here
                    # and skip their (useless) decoder calls.
                    if score > solution_score:
                        break

                    output, new_hidden_states, new_cell_states = \
                        self._translate_one_word(translation[-1], i_hidden_states, i_cell_states, attentions)
                    # NOTE(review): assumes output.data is a 1-D NumPy array of
                    # next-word probabilities (its log is taken below) -- confirm.
                    probs = output.data

                    # Visit candidate next words from most to least probable.
                    for next_wid in np.argsort(probs)[::-1]:
                        if probs[next_wid] < 1e-6:
                            break  # the remaining words are negligible
                        next_score = score - np.log(probs[next_wid])
                        if next_score > solution_score:
                            break  # the remaining words can only score worse

                        if next_wid == EOS:
                            if next_score < solution_score:
                                # [1:] drops the leading start EOS; int() turns
                                # NumPy indices into plain ints (List[int]).
                                solution = [int(w) for w in translation[1:]]
                                solution_score = next_score
                        else:
                            candidates.append(
                                (next_score, translation + [next_wid], new_hidden_states, new_cell_states))

                beam = candidates

            return solution
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号