beam_search_decoder.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:neuralmonkey 作者: ufal 项目源码 文件源码
def _decoding_loop(self) -> BeamSearchOutput:
        """Build the beam-search decoding loop and return its final output.

        The first decoding step is run outside ``tf.while_loop``, guarded by
        the same step-limit predicate. The initial beam state is a single
        state rather than a beam-sized one, and ``tf.while_loop`` needs the
        loop variables to keep the same shapes between iterations. All
        remaining steps then run inside ``tf.while_loop``.

        Returns:
            A ``BeamSearchOutput`` with:
            - the ``.stack()``-ed scores, parent ids and token ids taken from
              the final ``bs_output``;
            - the feedables of the final decoder loop state;
            - the final beam-search state;
            - an empty list of attention loop states (not collected yet).
        """
        step_fn = self.get_body()
        start_state = self.get_initial_loop_state()

        def keep_going(*loop_vars) -> tf.Tensor:
            # Keep decoding while (step - 1) is still below self._max_steps.
            loop_state = BeamSearchLoopState(*loop_vars)
            current_step = loop_state.decoder_loop_state.feedables.step
            return tf.less(current_step - 1, self._max_steps)

        # Run the first step by hand, because the starting state is not yet
        # beam-sized and while_loop needs shape-invariant loop variables.
        # When running ensembles, the intent (per the original note) is that
        # ensembled logprobs reach the step function before this manual step.
        first_step_state = tf.cond(
            keep_going(*start_state),
            lambda: step_fn(*start_state),
            lambda: start_state)

        final = tf.while_loop(keep_going, step_fn, first_step_state)

        # Keyword arguments are evaluated left to right, so the stack ops are
        # created in the order scores, parent_ids, token_ids.
        search_outputs = final.bs_output
        step_output = SearchStepOutput(
            scores=search_outputs.scores.stack(),
            parent_ids=search_outputs.parent_ids.stack(),
            token_ids=search_outputs.token_ids.stack())

        # TODO: return att_loop_states properly
        return BeamSearchOutput(
            last_search_step_output=step_output,
            last_dec_loop_state=final.decoder_loop_state.feedables,
            last_search_state=final.bs_state,
            attention_loop_states=[])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号