decoder.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:neuralmonkey 作者: ufal 项目源码 文件源码
def get_initial_loop_state(self) -> LoopState:
        """Build the loop state used at the first step of decoding.

        The returned ``LoopState`` bundles three groups of values:

        * histories -- empty, dynamically sized ``tf.TensorArray`` buffers
          for decoder outputs (pre-seeded at index 0 with
          ``self.initial_state``), logits, output symbols and mask, plus the
          ``initial_loop_state()`` of every non-``None`` attention;
        * constants -- a ``DecoderConstants`` holding ``self.train_inputs``;
        * feedables -- the step-0 inputs: ``step`` set to 0, an all-false
          ``finished`` vector of length ``self.batch_size``,
          ``self.go_symbols`` as the input symbol, all-zero logits of width
          ``len(self.vocabulary)``, ``self.initial_state`` as both the
          previous RNN state and previous RNN output, and one all-zero
          context vector per attention.
        """
        def growable_array(dtype, label):
            # Buffer that starts empty and grows as the loop writes into it.
            return tf.TensorArray(dtype=dtype, dynamic_size=True,
                                  size=0, name=label)

        # Seed the decoder-output history with the initial state at index 0.
        output_history = growable_array(tf.float32, "decoder_outputs")
        output_history = output_history.write(0, self.initial_state)

        logit_history = growable_array(tf.float32, "logits")
        symbol_history = growable_array(tf.int32, "outputs")

        # NOTE(review): unlike ``attention_states`` below, this loop does not
        # skip ``None`` entries of ``self.attentions`` -- confirm that
        # ``None`` attentions cannot reach this point.
        zero_contexts = []
        for attention in self.attentions:
            zero_contexts.append(
                tf.zeros([self.batch_size, attention.context_vector_size]))

        mask_history = growable_array(tf.bool, "mask")

        attention_states = []
        for attention in self.attentions:
            if attention is not None:
                attention_states.append(attention.initial_loop_state())

        # pylint: disable=not-callable
        # Keyword arguments are evaluated left to right, so the ops below are
        # created in the same order as the fields are listed.
        step_inputs = RNNFeedables(
            # general:
            step=0,
            finished=tf.zeros([self.batch_size], dtype=tf.bool),
            input_symbol=self.go_symbols,
            prev_logits=tf.zeros([self.batch_size, len(self.vocabulary)]),
            # rnn-specific:
            prev_rnn_state=self.initial_state,
            prev_rnn_output=self.initial_state,
            prev_contexts=zero_contexts)

        return LoopState(
            histories=RNNHistories(
                attention_histories=attention_states,
                logits=logit_history,
                decoder_outputs=output_history,
                outputs=symbol_history,
                mask=mask_history),
            constants=DecoderConstants(train_inputs=self.train_inputs),
            feedables=step_inputs)
        # pylint: enable=not-callable
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号