autoregressive.py source code

python
Project: neuralmonkey  Author: ufal
def decoding_loop(self, train_mode: bool, sample: bool = False) -> Tuple[
            tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
        """Run the decoding while loop.

        Calls get_initial_loop_state and constructs tf.while_loop
        with the continuation criterion returned from loop_continue_criterion,
        and body function returned from get_body.

        After finishing the tf.while_loop, it calls finalize_loop
        to further postprocess the final decoder loop state (usually
        by stacking TensorArrays containing decoding histories).

        Arguments:
            train_mode: Boolean flag, telling whether this is
                a training run.
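            sample: Boolean flag, telling whether we should sample
                the output symbols from the output distribution instead
                of using argmax or gold data.

        Returns:
            A tuple of (logits, decoder outputs, mask, decoded symbols),
            each stacked from the corresponding history in the final
            loop state.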
        """
        initial_loop_state = self.get_initial_loop_state()
        final_loop_state = tf.while_loop(
            self.loop_continue_criterion,
            self.get_body(train_mode, sample),
            initial_loop_state)

        self.finalize_loop(final_loop_state, train_mode)

        logits = final_loop_state.histories.logits.stack()
        decoder_outputs = final_loop_state.histories.decoder_outputs.stack()
        decoded = final_loop_state.histories.outputs.stack()

        # TODO: the mask should also include the end symbol
        mask = final_loop_state.histories.mask.stack()

        return logits, decoder_outputs, mask, decoded
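
The control flow described in the docstring (initial state, continuation criterion, body, then stacking the histories) can be illustrated without TensorFlow. The following is only a plain-Python sketch of that pattern, not the neuralmonkey implementation: `LoopState`, `while_loop`, `END_SYMBOL`, `MAX_STEPS`, and the fixed symbol list passed to `get_body` are all hypothetical stand-ins, with Python lists playing the role of the TensorArray histories.

```python
from typing import Callable, List, NamedTuple


class LoopState(NamedTuple):
    """Hypothetical loop state: step counter, finished flag, histories."""
    step: int
    finished: bool
    outputs: List[int]  # stands in for the TensorArray of decoded symbols
    mask: List[int]     # 1 for a real symbol, 0 for the end symbol


END_SYMBOL = 0   # assumed end-of-sequence id, for illustration only
MAX_STEPS = 10   # assumed upper bound on decoding steps


def while_loop(cond, body, state):
    """Plain-Python stand-in for tf.while_loop."""
    while cond(state):
        state = body(state)
    return state


def get_initial_loop_state() -> LoopState:
    return LoopState(step=0, finished=False, outputs=[], mask=[])


def loop_continue_criterion(state: LoopState) -> bool:
    # Keep decoding until the end symbol was emitted or the limit is hit.
    return not state.finished and state.step < MAX_STEPS


def get_body(symbols: List[int]) -> Callable[[LoopState], LoopState]:
    """Return the loop body; `symbols` replaces a real model's predictions."""
    def body(state: LoopState) -> LoopState:
        if state.step < len(symbols):
            symbol = symbols[state.step]
        else:
            symbol = END_SYMBOL
        is_end = symbol == END_SYMBOL
        # Append to the histories instead of mutating them, mirroring how
        # a TensorArray write returns a new TensorArray.
        return LoopState(
            step=state.step + 1,
            finished=is_end,
            outputs=state.outputs + [symbol],
            mask=state.mask + [0 if is_end else 1])
    return body


final_state = while_loop(
    loop_continue_criterion,
    get_body([5, 7, 3, 0, 9]),
    get_initial_loop_state())

decoded, mask = final_state.outputs, final_state.mask
print(decoded)  # [5, 7, 3, 0]
print(mask)     # [1, 1, 1, 0]
```

In this sketch the mask is 0 at the end symbol, which is one reading of the TODO in the original code; whether the real decoder behaves this way is an assumption.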